1//===- OpenMPToLLVMIRTranslation.cpp - Translate OpenMP dialect to LLVM IR-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a translation between the MLIR OpenMP dialect and LLVM
10// IR.
11//
12//===----------------------------------------------------------------------===//
13#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h"
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
17#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
18#include "mlir/Dialect/OpenMP/OpenMPInterfaces.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Operation.h"
21#include "mlir/Support/LLVM.h"
22#include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h"
23#include "mlir/Target/LLVMIR/ModuleTranslation.h"
24#include "mlir/Transforms/RegionUtils.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SetVector.h"
28#include "llvm/ADT/SmallVector.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Frontend/OpenMP/OMPConstants.h"
31#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
32#include "llvm/IR/Constants.h"
33#include "llvm/IR/DebugInfoMetadata.h"
34#include "llvm/IR/DerivedTypes.h"
35#include "llvm/IR/IRBuilder.h"
36#include "llvm/IR/MDBuilder.h"
37#include "llvm/IR/ReplaceConstant.h"
38#include "llvm/Support/FileSystem.h"
39#include "llvm/TargetParser/Triple.h"
40#include "llvm/Transforms/Utils/ModuleUtils.h"
41
42#include <any>
43#include <cstdint>
44#include <iterator>
45#include <numeric>
46#include <optional>
47#include <utility>
48
49using namespace mlir;
50
51namespace {
52static llvm::omp::ScheduleKind
53convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
54 if (!schedKind.has_value())
55 return llvm::omp::OMP_SCHEDULE_Default;
56 switch (schedKind.value()) {
57 case omp::ClauseScheduleKind::Static:
58 return llvm::omp::OMP_SCHEDULE_Static;
59 case omp::ClauseScheduleKind::Dynamic:
60 return llvm::omp::OMP_SCHEDULE_Dynamic;
61 case omp::ClauseScheduleKind::Guided:
62 return llvm::omp::OMP_SCHEDULE_Guided;
63 case omp::ClauseScheduleKind::Auto:
64 return llvm::omp::OMP_SCHEDULE_Auto;
65 case omp::ClauseScheduleKind::Runtime:
66 return llvm::omp::OMP_SCHEDULE_Runtime;
67 }
68 llvm_unreachable("unhandled schedule clause argument");
69}
70
71/// ModuleTranslation stack frame for OpenMP operations. This keeps track of the
72/// insertion points for allocas.
73class OpenMPAllocaStackFrame
74 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPAllocaStackFrame> {
75public:
76 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPAllocaStackFrame)
77
78 explicit OpenMPAllocaStackFrame(llvm::OpenMPIRBuilder::InsertPointTy allocaIP)
79 : allocaInsertPoint(allocaIP) {}
80 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
81};
82
83/// Stack frame to hold a \see llvm::CanonicalLoopInfo representing the
84/// collapsed canonical loop information corresponding to an \c omp.loop_nest
85/// operation.
86class OpenMPLoopInfoStackFrame
87 : public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
88public:
89 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
90 llvm::CanonicalLoopInfo *loopInfo = nullptr;
91};
92
93/// Custom error class to signal translation errors that don't need reporting,
94/// since encountering them will have already triggered relevant error messages.
95///
96/// Its purpose is to serve as the glue between MLIR failures represented as
97/// \see LogicalResult instances and \see llvm::Error instances used to
98/// propagate errors through the \see llvm::OpenMPIRBuilder. Generally, when an
99/// error of the first type is raised, a message is emitted directly (the \see
100/// LogicalResult itself does not hold any information). If we need to forward
101/// this error condition as an \see llvm::Error while avoiding triggering some
102/// redundant error reporting later on, we need a custom \see llvm::ErrorInfo
103/// class to just signal this situation has happened.
104///
105/// For example, this class should be used to trigger errors from within
106/// callbacks passed to the \see OpenMPIRBuilder when they were triggered by the
107/// translation of their own regions. This unclutters the error log from
108/// redundant messages.
109class PreviouslyReportedError
110 : public llvm::ErrorInfo<PreviouslyReportedError> {
111public:
112 void log(raw_ostream &) const override {
113 // Do not log anything.
114 }
115
116 std::error_code convertToErrorCode() const override {
117 llvm_unreachable(
118 "PreviouslyReportedError doesn't support ECError conversion");
119 }
120
121 // Used by ErrorInfo::classID.
122 static char ID;
123};
124
125char PreviouslyReportedError::ID = 0;
126
127/*
128 * Custom class for processing linear clause for omp.wsloop
129 * and omp.simd. Linear clause translation requires setup,
130 * initialization, update, and finalization at varying
131 * basic blocks in the IR. This class helps maintain
132 * internal state to allow consistent translation in
133 * each of these stages.
134 */
135
136class LinearClauseProcessor {
137
138private:
139 SmallVector<llvm::Value *> linearPreconditionVars;
140 SmallVector<llvm::Value *> linearLoopBodyTemps;
141 SmallVector<llvm::AllocaInst *> linearOrigVars;
142 SmallVector<llvm::Value *> linearOrigVal;
143 SmallVector<llvm::Value *> linearSteps;
144 llvm::BasicBlock *linearFinalizationBB;
145 llvm::BasicBlock *linearExitBB;
146 llvm::BasicBlock *linearLastIterExitBB;
147
148public:
149 // Allocate space for linear variabes
150 void createLinearVar(llvm::IRBuilderBase &builder,
151 LLVM::ModuleTranslation &moduleTranslation,
152 mlir::Value &linearVar) {
153 if (llvm::AllocaInst *linearVarAlloca = dyn_cast<llvm::AllocaInst>(
154 Val: moduleTranslation.lookupValue(value: linearVar))) {
155 linearPreconditionVars.push_back(Elt: builder.CreateAlloca(
156 Ty: linearVarAlloca->getAllocatedType(), ArraySize: nullptr, Name: ".linear_var"));
157 llvm::Value *linearLoopBodyTemp = builder.CreateAlloca(
158 Ty: linearVarAlloca->getAllocatedType(), ArraySize: nullptr, Name: ".linear_result");
159 linearOrigVal.push_back(Elt: moduleTranslation.lookupValue(value: linearVar));
160 linearLoopBodyTemps.push_back(Elt: linearLoopBodyTemp);
161 linearOrigVars.push_back(Elt: linearVarAlloca);
162 }
163 }
164
165 // Initialize linear step
166 inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation,
167 mlir::Value &linearStep) {
168 linearSteps.push_back(Elt: moduleTranslation.lookupValue(value: linearStep));
169 }
170
171 // Emit IR for initialization of linear variables
172 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
173 initLinearVar(llvm::IRBuilderBase &builder,
174 LLVM::ModuleTranslation &moduleTranslation,
175 llvm::BasicBlock *loopPreHeader) {
176 builder.SetInsertPoint(loopPreHeader->getTerminator());
177 for (size_t index = 0; index < linearOrigVars.size(); index++) {
178 llvm::LoadInst *linearVarLoad = builder.CreateLoad(
179 Ty: linearOrigVars[index]->getAllocatedType(), Ptr: linearOrigVars[index]);
180 builder.CreateStore(Val: linearVarLoad, Ptr: linearPreconditionVars[index]);
181 }
182 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
183 moduleTranslation.getOpenMPBuilder()->createBarrier(
184 builder.saveIP(), llvm::omp::OMPD_barrier);
185 return afterBarrierIP;
186 }
187
188 // Emit IR for updating Linear variables
189 void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody,
190 llvm::Value *loopInductionVar) {
191 builder.SetInsertPoint(loopBody->getTerminator());
192 for (size_t index = 0; index < linearPreconditionVars.size(); index++) {
193 // Emit increments for linear vars
194 llvm::LoadInst *linearVarStart =
195 builder.CreateLoad(Ty: linearOrigVars[index]->getAllocatedType(),
196
197 Ptr: linearPreconditionVars[index]);
198 auto mulInst = builder.CreateMul(LHS: loopInductionVar, RHS: linearSteps[index]);
199 auto addInst = builder.CreateAdd(LHS: linearVarStart, RHS: mulInst);
200 builder.CreateStore(Val: addInst, Ptr: linearLoopBodyTemps[index]);
201 }
202 }
203
204 // Linear variable finalization is conditional on the last logical iteration.
205 // Create BB splits to manage the same.
206 void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder,
207 llvm::BasicBlock *loopExit) {
208 linearFinalizationBB = loopExit->splitBasicBlock(
209 I: loopExit->getTerminator(), BBName: "omp_loop.linear_finalization");
210 linearExitBB = linearFinalizationBB->splitBasicBlock(
211 I: linearFinalizationBB->getTerminator(), BBName: "omp_loop.linear_exit");
212 linearLastIterExitBB = linearFinalizationBB->splitBasicBlock(
213 I: linearFinalizationBB->getTerminator(), BBName: "omp_loop.linear_lastiter_exit");
214 }
215
216 // Finalize the linear vars
217 llvm::OpenMPIRBuilder::InsertPointOrErrorTy
218 finalizeLinearVar(llvm::IRBuilderBase &builder,
219 LLVM::ModuleTranslation &moduleTranslation,
220 llvm::Value *lastIter) {
221 // Emit condition to check whether last logical iteration is being executed
222 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
223 llvm::Value *loopLastIterLoad = builder.CreateLoad(
224 Ty: llvm::Type::getInt32Ty(C&: builder.getContext()), Ptr: lastIter);
225 llvm::Value *isLast =
226 builder.CreateCmp(Pred: llvm::CmpInst::ICMP_NE, LHS: loopLastIterLoad,
227 RHS: llvm::ConstantInt::get(
228 Ty: llvm::Type::getInt32Ty(C&: builder.getContext()), V: 0));
229 // Store the linear variable values to original variables.
230 builder.SetInsertPoint(linearLastIterExitBB->getTerminator());
231 for (size_t index = 0; index < linearOrigVars.size(); index++) {
232 llvm::LoadInst *linearVarTemp =
233 builder.CreateLoad(Ty: linearOrigVars[index]->getAllocatedType(),
234 Ptr: linearLoopBodyTemps[index]);
235 builder.CreateStore(Val: linearVarTemp, Ptr: linearOrigVars[index]);
236 }
237
238 // Create conditional branch such that the linear variable
239 // values are stored to original variables only at the
240 // last logical iteration
241 builder.SetInsertPoint(linearFinalizationBB->getTerminator());
242 builder.CreateCondBr(Cond: isLast, True: linearLastIterExitBB, False: linearExitBB);
243 linearFinalizationBB->getTerminator()->eraseFromParent();
244 // Emit barrier
245 builder.SetInsertPoint(linearExitBB->getTerminator());
246 return moduleTranslation.getOpenMPBuilder()->createBarrier(
247 builder.saveIP(), llvm::omp::OMPD_barrier);
248 }
249
250 // Rewrite all uses of the original variable in `BBName`
251 // with the linear variable in-place
252 void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName,
253 size_t varIndex) {
254 llvm::SmallVector<llvm::User *> users;
255 for (llvm::User *user : linearOrigVal[varIndex]->users())
256 users.push_back(Elt: user);
257 for (auto *user : users) {
258 if (auto *userInst = dyn_cast<llvm::Instruction>(Val: user)) {
259 if (userInst->getParent()->getName().str() == BBName)
260 user->replaceUsesOfWith(From: linearOrigVal[varIndex],
261 To: linearLoopBodyTemps[varIndex]);
262 }
263 }
264 }
265};
266
267} // namespace
268
269/// Looks up from the operation from and returns the PrivateClauseOp with
270/// name symbolName
271static omp::PrivateClauseOp findPrivatizer(Operation *from,
272 SymbolRefAttr symbolName) {
273 omp::PrivateClauseOp privatizer =
274 SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(from,
275 symbolName);
276 assert(privatizer && "privatizer not found in the symbol table");
277 return privatizer;
278}
279
/// Check whether translation to LLVM IR for the given operation is currently
/// supported. If not, descriptive diagnostics will be emitted to let users know
/// this is a not-yet-implemented feature.
///
/// Every unsupported clause found emits its own error, so a single call may
/// produce several diagnostics. Clause checks that only warn (e.g. `hint`) do
/// not affect the returned result.
///
/// \returns success if no unimplemented features are needed to translate the
/// given operation.
static LogicalResult checkImplementationStatus(Operation &op) {
  // Emits the standard "not yet implemented" error for `clauseName` on `op`
  // and returns it so it can be stored as a failed LogicalResult.
  auto todo = [&op](StringRef clauseName) {
    return op.emitError() << "not yet implemented: Unhandled clause "
                          << clauseName << " in " << op.getName()
                          << " operation";
  };

  // Per-clause checks. Each takes the typed operation and, when the clause is
  // present but unsupported, overwrites `result` with the failure from `todo`.
  auto checkAllocate = [&todo](auto op, LogicalResult &result) {
    if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
      result = todo("allocate");
  };
  auto checkBare = [&todo](auto op, LogicalResult &result) {
    if (op.getBare())
      result = todo("ompx_bare");
  };
  auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
    omp::ClauseCancellationConstructType cancelledDirective =
        op.getCancelDirective();
    // Cancelling a taskloop is not yet supported because we don't yet have LLVM
    // IR conversion for taskloop
    if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
      // Walk up to the closest enclosing operation from the same dialect as
      // `op`; only that operation is inspected below.
      Operation *parent = op->getParentOp();
      while (parent) {
        if (parent->getDialect() == op->getDialect())
          break;
        parent = parent->getParentOp();
      }
      if (isa_and_nonnull<omp::TaskloopOp>(parent))
        result = todo("cancel directive inside of taskloop");
    }
  };
  auto checkDepend = [&todo](auto op, LogicalResult &result) {
    if (!op.getDependVars().empty() || op.getDependKinds())
      result = todo("depend");
  };
  auto checkDevice = [&todo](auto op, LogicalResult &result) {
    if (op.getDevice())
      result = todo("device");
  };
  auto checkDistSchedule = [&todo](auto op, LogicalResult &result) {
    if (op.getDistScheduleChunkSize())
      result = todo("dist_schedule with chunk_size");
  };
  // Only warns: the hint is dropped but translation still proceeds.
  auto checkHint = [](auto op, LogicalResult &) {
    if (op.getHint())
      op.emitWarning("hint clause discarded");
  };
  auto checkInReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getInReductionVars().empty() || op.getInReductionByref() ||
        op.getInReductionSyms())
      result = todo("in_reduction");
  };
  auto checkIsDevicePtr = [&todo](auto op, LogicalResult &result) {
    if (!op.getIsDevicePtrVars().empty())
      result = todo("is_device_ptr");
  };
  auto checkLinear = [&todo](auto op, LogicalResult &result) {
    if (!op.getLinearVars().empty() || !op.getLinearStepVars().empty())
      result = todo("linear");
  };
  auto checkNowait = [&todo](auto op, LogicalResult &result) {
    if (op.getNowait())
      result = todo("nowait");
  };
  auto checkOrder = [&todo](auto op, LogicalResult &result) {
    if (op.getOrder() || op.getOrderMod())
      result = todo("order");
  };
  auto checkParLevelSimd = [&todo](auto op, LogicalResult &result) {
    if (op.getParLevelSimd())
      result = todo("parallelization-level");
  };
  auto checkPriority = [&todo](auto op, LogicalResult &result) {
    if (op.getPriority())
      result = todo("priority");
  };
  auto checkPrivate = [&todo](auto op, LogicalResult &result) {
    if constexpr (std::is_same_v<std::decay_t<decltype(op)>, omp::TargetOp>) {
      // Privatization is supported only for included target tasks.
      if (!op.getPrivateVars().empty() && op.getNowait())
        result = todo("privatization for deferred target tasks");
    } else {
      if (!op.getPrivateVars().empty() || op.getPrivateSyms())
        result = todo("privatization");
    }
  };
  // Reductions as a whole are rejected only for TeamsOp/SimdOp. A non-default
  // reduction modifier is rejected for every operation this is called on.
  // NOTE(review): the TeamsOp case in the TypeSwitch below does not call
  // checkReduction, so the TeamsOp branch here is never reached. Confirm
  // whether this is intentional (e.g. teams reductions are now supported).
  auto checkReduction = [&todo](auto op, LogicalResult &result) {
    if (isa<omp::TeamsOp>(op) || isa<omp::SimdOp>(op))
      if (!op.getReductionVars().empty() || op.getReductionByref() ||
          op.getReductionSyms())
        result = todo("reduction");
    if (op.getReductionMod() &&
        op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
      result = todo("reduction with modifier");
  };
  auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
    if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
        op.getTaskReductionSyms())
      result = todo("task_reduction");
  };
  auto checkUntied = [&todo](auto op, LogicalResult &result) {
    if (op.getUntied())
      result = todo("untied");
  };

  // Dispatch on the concrete operation and run the clause checks relevant to
  // it. The order of calls within each case is the order of emitted errors.
  LogicalResult result = success();
  llvm::TypeSwitch<Operation &>(op)
      .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
      .Case([&](omp::CancellationPointOp op) {
        checkCancelDirective(op, result);
      })
      .Case([&](omp::DistributeOp op) {
        checkAllocate(op, result);
        checkDistSchedule(op, result);
        checkOrder(op, result);
      })
      .Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
      .Case([&](omp::SectionsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SingleOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TeamsOp op) {
        checkAllocate(op, result);
        checkPrivate(op, result);
      })
      .Case([&](omp::TaskOp op) {
        checkAllocate(op, result);
        checkInReduction(op, result);
      })
      .Case([&](omp::TaskgroupOp op) {
        checkAllocate(op, result);
        checkTaskReduction(op, result);
      })
      .Case([&](omp::TaskwaitOp op) {
        checkDepend(op, result);
        checkNowait(op, result);
      })
      .Case([&](omp::TaskloopOp op) {
        // TODO: Add other clauses check
        checkUntied(op, result);
        checkPriority(op, result);
      })
      .Case([&](omp::WsloopOp op) {
        checkAllocate(op, result);
        checkLinear(op, result);
        checkOrder(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::ParallelOp op) {
        checkAllocate(op, result);
        checkReduction(op, result);
      })
      .Case([&](omp::SimdOp op) {
        checkLinear(op, result);
        checkReduction(op, result);
      })
      .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
            omp::AtomicCaptureOp>([&](auto op) { checkHint(op, result); })
      .Case<omp::TargetEnterDataOp, omp::TargetExitDataOp, omp::TargetUpdateOp>(
          [&](auto op) { checkDepend(op, result); })
      .Case([&](omp::TargetOp op) {
        checkAllocate(op, result);
        checkBare(op, result);
        checkDevice(op, result);
        checkInReduction(op, result);
        checkIsDevicePtr(op, result);
        checkPrivate(op, result);
      })
      .Default([](Operation &) {
        // Assume all clauses for an operation can be translated unless they are
        // checked above.
      });
  return result;
}
465
466static LogicalResult handleError(llvm::Error error, Operation &op) {
467 LogicalResult result = success();
468 if (error) {
469 llvm::handleAllErrors(
470 E: std::move(error),
471 Handlers: [&](const PreviouslyReportedError &) { result = failure(); },
472 Handlers: [&](const llvm::ErrorInfoBase &err) {
473 result = op.emitError(message: err.message());
474 });
475 }
476 return result;
477}
478
479template <typename T>
480static LogicalResult handleError(llvm::Expected<T> &result, Operation &op) {
481 if (!result)
482 return handleError(result.takeError(), op);
483
484 return success();
485}
486
487/// Find the insertion point for allocas given the current insertion point for
488/// normal operations in the builder.
489static llvm::OpenMPIRBuilder::InsertPointTy
490findAllocaInsertPoint(llvm::IRBuilderBase &builder,
491 LLVM::ModuleTranslation &moduleTranslation) {
492 // If there is an alloca insertion point on stack, i.e. we are in a nested
493 // operation and a specific point was provided by some surrounding operation,
494 // use it.
495 llvm::OpenMPIRBuilder::InsertPointTy allocaInsertPoint;
496 WalkResult walkResult = moduleTranslation.stackWalk<OpenMPAllocaStackFrame>(
497 callback: [&](OpenMPAllocaStackFrame &frame) {
498 allocaInsertPoint = frame.allocaInsertPoint;
499 return WalkResult::interrupt();
500 });
501 if (walkResult.wasInterrupted())
502 return allocaInsertPoint;
503
504 // Otherwise, insert to the entry block of the surrounding function.
505 // If the current IRBuilder InsertPoint is the function's entry, it cannot
506 // also be used for alloca insertion which would result in insertion order
507 // confusion. Create a new BasicBlock for the Builder and use the entry block
508 // for the allocs.
509 // TODO: Create a dedicated alloca BasicBlock at function creation such that
510 // we do not need to move the current InertPoint here.
511 if (builder.GetInsertBlock() ==
512 &builder.GetInsertBlock()->getParent()->getEntryBlock()) {
513 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end() &&
514 "Assuming end of basic block");
515 llvm::BasicBlock *entryBB = llvm::BasicBlock::Create(
516 Context&: builder.getContext(), Name: "entry", Parent: builder.GetInsertBlock()->getParent(),
517 InsertBefore: builder.GetInsertBlock()->getNextNode());
518 builder.CreateBr(Dest: entryBB);
519 builder.SetInsertPoint(entryBB);
520 }
521
522 llvm::BasicBlock &funcEntryBlock =
523 builder.GetInsertBlock()->getParent()->getEntryBlock();
524 return llvm::OpenMPIRBuilder::InsertPointTy(
525 &funcEntryBlock, funcEntryBlock.getFirstInsertionPt());
526}
527
528/// Find the loop information structure for the loop nest being translated. It
529/// will return a `null` value unless called from the translation function for
530/// a loop wrapper operation after successfully translating its body.
531static llvm::CanonicalLoopInfo *
532findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
533 llvm::CanonicalLoopInfo *loopInfo = nullptr;
534 moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
535 callback: [&](OpenMPLoopInfoStackFrame &frame) {
536 loopInfo = frame.loopInfo;
537 return WalkResult::interrupt();
538 });
539 return loopInfo;
540}
541
/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, creating a branch from the `sourceBlock` to the entry block of the
/// region, and a branch from any block with a successor-less OpenMP terminator
/// to `continuationBlock`. Populates `continuationBlockPHIs` with the PHI nodes
/// of the continuation block if provided.
///
/// \param region The region to translate.
/// \param blockName Name given to every LLVM block created for the region.
/// \param builder Builder whose current insertion point is split to form the
///        source block and the "omp.region.cont" continuation block.
/// \param continuationBlockPHIs Receives one PHI per value yielded by the
///        region's `omp.yield` terminators. It must be provided when the
///        region yields values (asserted). NOTE(review): PHIs are appended
///        and later indexed from 0, so this assumes the vector is empty on
///        entry. Confirm with callers.
/// \returns the continuation block, or a PreviouslyReportedError if any block
///          of the region failed to translate.
static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
    Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
    LLVM::ModuleTranslation &moduleTranslation,
    SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());

  // Split at the current insertion point. The builder stays in the original
  // (source) block, which now ends in a branch to the continuation block.
  llvm::BasicBlock *continuationBlock =
      splitBB(builder, true, "omp.region.cont");
  llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();

  // Pre-create and map an LLVM block for every MLIR block so that branches
  // between region blocks can be resolved during conversion. Each new block
  // is inserted right after the current insertion block.
  llvm::LLVMContext &llvmContext = builder.getContext();
  for (Block &bb : region) {
    llvm::BasicBlock *llvmBB = llvm::BasicBlock::Create(
        llvmContext, blockName, builder.GetInsertBlock()->getParent(),
        builder.GetInsertBlock()->getNextNode());
    moduleTranslation.mapBlock(&bb, llvmBB);
  }

  llvm::Instruction *sourceTerminator = sourceBlock->getTerminator();

  // Terminators (namely YieldOp) may be forwarding values to the region that
  // need to be available in the continuation block. Collect the types of these
  // operands in preparation of creating PHI nodes. This is skipped for loop
  // wrapper operations, for which we know in advance they have no terminators.
  SmallVector<llvm::Type *> continuationBlockPHITypes;
  unsigned numYields = 0;

  if (!isLoopWrapper) {
    bool operandsProcessed = false;
    for (Block &bb : region.getBlocks()) {
      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
        if (!operandsProcessed) {
          // The first yield fixes the PHI types.
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            continuationBlockPHITypes.push_back(
                moduleTranslation.convertType(yield->getOperand(i).getType()));
          }
          operandsProcessed = true;
        } else {
          // Later yields are only checked (in asserts builds) for agreement.
          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
                 "mismatching number of values yielded from the region");
          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
            llvm::Type *operandType =
                moduleTranslation.convertType(yield->getOperand(i).getType());
            (void)operandType;
            assert(continuationBlockPHITypes[i] == operandType &&
                   "values of mismatching types yielded from the region");
          }
        }
        numYields++;
      }
    }
  }

  // Insert PHI nodes in the continuation block for any values forwarded by the
  // terminators in this region.
  if (!continuationBlockPHITypes.empty())
    assert(
        continuationBlockPHIs &&
        "expected continuation block PHIs if converted regions yield values");
  if (continuationBlockPHIs) {
    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    continuationBlockPHIs->reserve(continuationBlockPHITypes.size());
    builder.SetInsertPoint(continuationBlock, continuationBlock->begin());
    // One incoming value is expected per yielding block.
    for (llvm::Type *ty : continuationBlockPHITypes)
      continuationBlockPHIs->push_back(builder.CreatePHI(ty, numYields));
  }

  // Convert blocks one by one in topological order to ensure
  // defs are converted before uses.
  SetVector<Block *> blocks = getBlocksSortedByDominance(region);
  for (Block *bb : blocks) {
    llvm::BasicBlock *llvmBB = moduleTranslation.lookupBlock(bb);
    // Retarget the branch of the entry block to the entry block of the
    // converted region (regions are single-entry).
    if (bb->isEntryBlock()) {
      assert(sourceTerminator->getNumSuccessors() == 1 &&
             "provided entry block has multiple successors");
      assert(sourceTerminator->getSuccessor(0) == continuationBlock &&
             "ContinuationBlock is not the successor of the entry block");
      sourceTerminator->setSuccessor(0, llvmBB);
    }

    llvm::IRBuilderBase::InsertPointGuard guard(builder);
    // Entry block arguments are skipped; they are expected to be mapped by
    // the caller beforehand.
    if (failed(
            moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
      return llvm::make_error<PreviouslyReportedError>();

    // Create a direct branch here for loop wrappers to prevent their lack of a
    // terminator from causing a crash below.
    if (isLoopWrapper) {
      builder.CreateBr(continuationBlock);
      continue;
    }

    // Special handling for `omp.yield` and `omp.terminator` (we may have more
    // than one): they return the control to the parent OpenMP dialect operation
    // so replace them with the branch to the continuation block. We handle this
    // here to avoid relying inter-function communication through the
    // ModuleTranslation class to set up the correct insertion point. This is
    // also consistent with MLIR's idiom of handling special region terminators
    // in the same code that handles the region-owning operation.
    Operation *terminator = bb->getTerminator();
    if (isa<omp::TerminatorOp, omp::YieldOp>(terminator)) {
      builder.CreateBr(continuationBlock);

      // Feed each yielded value into the matching continuation-block PHI.
      for (unsigned i = 0, e = terminator->getNumOperands(); i < e; ++i)
        (*continuationBlockPHIs)[i]->addIncoming(
            moduleTranslation.lookupValue(terminator->getOperand(i)), llvmBB);
    }
  }
  // After all blocks have been traversed and values mapped, connect the PHI
  // nodes to the results of preceding blocks.
  LLVM::detail::connectPHINodes(region, moduleTranslation);

  // Remove the blocks and values defined in this region from the mapping since
  // they are not visible outside of this region. This allows the same region to
  // be converted several times, that is cloned, without clashes, and slightly
  // speeds up the lookups.
  moduleTranslation.forgetMapping(region);

  return continuationBlock;
}
669
670/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
671static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
672 switch (kind) {
673 case omp::ClauseProcBindKind::Close:
674 return llvm::omp::ProcBindKind::OMP_PROC_BIND_close;
675 case omp::ClauseProcBindKind::Master:
676 return llvm::omp::ProcBindKind::OMP_PROC_BIND_master;
677 case omp::ClauseProcBindKind::Primary:
678 return llvm::omp::ProcBindKind::OMP_PROC_BIND_primary;
679 case omp::ClauseProcBindKind::Spread:
680 return llvm::omp::ProcBindKind::OMP_PROC_BIND_spread;
681 }
682 llvm_unreachable("Unknown ClauseProcBindKind kind");
683}
684
685/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
686/// corresponding LLVM values of \p the interface's operands. This is useful
687/// when an OpenMP region with entry block arguments is converted to LLVM. In
688/// this case the block arguments are (part of) of the OpenMP region's entry
689/// arguments and the operands are (part of) of the operands to the OpenMP op
690/// containing the region.
691static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
692 omp::BlockArgOpenMPOpInterface blockArgIface) {
693 llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
694 blockArgIface.getBlockArgsPairs(blockArgsPairs);
695 for (auto [var, arg] : blockArgsPairs)
696 moduleTranslation.mapValue(mlir: arg, llvm: moduleTranslation.lookupValue(value: var));
697}
698
699/// Helper function to map block arguments defined by ignored loop wrappers to
700/// LLVM values and prevent any uses of those from triggering null pointer
701/// dereferences.
702///
703/// This must be called after block arguments of parent wrappers have already
704/// been mapped to LLVM IR values.
705static LogicalResult
706convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
707 LLVM::ModuleTranslation &moduleTranslation) {
708 // Map block arguments directly to the LLVM value associated to the
709 // corresponding operand. This is semantically equivalent to this wrapper not
710 // being present.
711 return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
712 .Case(caseFn: [&](omp::SimdOp op) {
713 forwardArgs(moduleTranslation,
714 cast<omp::BlockArgOpenMPOpInterface>(*op));
715 op.emitWarning() << "simd information on composite construct discarded";
716 return success();
717 })
718 .Default(defaultFn: [&](Operation *op) {
719 return op->emitError() << "cannot ignore wrapper";
720 });
721}
722
723/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
724static LogicalResult
725convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
726 LLVM::ModuleTranslation &moduleTranslation) {
727 auto maskedOp = cast<omp::MaskedOp>(opInst);
728 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
729
730 if (failed(Result: checkImplementationStatus(op&: opInst)))
731 return failure();
732
733 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
734 // MaskedOp has only one region associated with it.
735 auto &region = maskedOp.getRegion();
736 builder.restoreIP(IP: codeGenIP);
737 return convertOmpOpRegions(region, "omp.masked.region", builder,
738 moduleTranslation)
739 .takeError();
740 };
741
742 // TODO: Perform finalization actions for variables. This has to be
743 // called for variables which have destructors/finalizers.
744 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
745
746 llvm::Value *filterVal = nullptr;
747 if (auto filterVar = maskedOp.getFilteredThreadId()) {
748 filterVal = moduleTranslation.lookupValue(value: filterVar);
749 } else {
750 llvm::LLVMContext &llvmContext = builder.getContext();
751 filterVal =
752 llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext), /*V=*/0);
753 }
754 assert(filterVal != nullptr);
755 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
756 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
757 moduleTranslation.getOpenMPBuilder()->createMasked(ompLoc, bodyGenCB,
758 finiCB, filterVal);
759
760 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
761 return failure();
762
763 builder.restoreIP(IP: *afterIP);
764 return success();
765}
766
767/// Converts an OpenMP 'master' operation into LLVM IR using OpenMPIRBuilder.
768static LogicalResult
769convertOmpMaster(Operation &opInst, llvm::IRBuilderBase &builder,
770 LLVM::ModuleTranslation &moduleTranslation) {
771 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
772 auto masterOp = cast<omp::MasterOp>(opInst);
773
774 if (failed(Result: checkImplementationStatus(op&: opInst)))
775 return failure();
776
777 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
778 // MasterOp has only one region associated with it.
779 auto &region = masterOp.getRegion();
780 builder.restoreIP(IP: codeGenIP);
781 return convertOmpOpRegions(region, "omp.master.region", builder,
782 moduleTranslation)
783 .takeError();
784 };
785
786 // TODO: Perform finalization actions for variables. This has to be
787 // called for variables which have destructors/finalizers.
788 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
789
790 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
791 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
792 moduleTranslation.getOpenMPBuilder()->createMaster(ompLoc, bodyGenCB,
793 finiCB);
794
795 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
796 return failure();
797
798 builder.restoreIP(IP: *afterIP);
799 return success();
800}
801
802/// Converts an OpenMP 'critical' operation into LLVM IR using OpenMPIRBuilder.
803static LogicalResult
804convertOmpCritical(Operation &opInst, llvm::IRBuilderBase &builder,
805 LLVM::ModuleTranslation &moduleTranslation) {
806 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
807 auto criticalOp = cast<omp::CriticalOp>(opInst);
808
809 if (failed(Result: checkImplementationStatus(op&: opInst)))
810 return failure();
811
812 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
813 // CriticalOp has only one region associated with it.
814 auto &region = cast<omp::CriticalOp>(opInst).getRegion();
815 builder.restoreIP(IP: codeGenIP);
816 return convertOmpOpRegions(region, "omp.critical.region", builder,
817 moduleTranslation)
818 .takeError();
819 };
820
821 // TODO: Perform finalization actions for variables. This has to be
822 // called for variables which have destructors/finalizers.
823 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
824
825 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
826 llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext();
827 llvm::Constant *hint = nullptr;
828
829 // If it has a name, it probably has a hint too.
830 if (criticalOp.getNameAttr()) {
831 // The verifiers in OpenMP Dialect guarentee that all the pointers are
832 // non-null
833 auto symbolRef = cast<SymbolRefAttr>(criticalOp.getNameAttr());
834 auto criticalDeclareOp =
835 SymbolTable::lookupNearestSymbolFrom<omp::CriticalDeclareOp>(criticalOp,
836 symbolRef);
837 hint =
838 llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: llvmContext),
839 V: static_cast<int>(criticalDeclareOp.getHint()));
840 }
841 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
842 moduleTranslation.getOpenMPBuilder()->createCritical(
843 Loc: ompLoc, BodyGenCB: bodyGenCB, FiniCB: finiCB, CriticalName: criticalOp.getName().value_or(""), HintInst: hint);
844
845 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
846 return failure();
847
848 builder.restoreIP(IP: *afterIP);
849 return success();
850}
851
852/// A util to collect info needed to convert delayed privatizers from MLIR to
853/// LLVM.
854struct PrivateVarsInfo {
855 template <typename OP>
856 PrivateVarsInfo(OP op)
857 : blockArgs(
858 cast<omp::BlockArgOpenMPOpInterface>(*op).getPrivateBlockArgs()) {
859 mlirVars.reserve(N: blockArgs.size());
860 llvmVars.reserve(N: blockArgs.size());
861 collectPrivatizationDecls<OP>(op);
862
863 for (mlir::Value privateVar : op.getPrivateVars())
864 mlirVars.push_back(Elt: privateVar);
865 }
866
867 MutableArrayRef<BlockArgument> blockArgs;
868 SmallVector<mlir::Value> mlirVars;
869 SmallVector<llvm::Value *> llvmVars;
870 SmallVector<omp::PrivateClauseOp> privatizers;
871
872private:
873 /// Populates `privatizations` with privatization declarations used for the
874 /// given op.
875 template <class OP>
876 void collectPrivatizationDecls(OP op) {
877 std::optional<ArrayAttr> attr = op.getPrivateSyms();
878 if (!attr)
879 return;
880
881 privatizers.reserve(privatizers.size() + attr->size());
882 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
883 privatizers.push_back(findPrivatizer(op, symbolRef));
884 }
885 }
886};
887
888/// Populates `reductions` with reduction declarations used in the given op.
889template <typename T>
890static void
891collectReductionDecls(T op,
892 SmallVectorImpl<omp::DeclareReductionOp> &reductions) {
893 std::optional<ArrayAttr> attr = op.getReductionSyms();
894 if (!attr)
895 return;
896
897 reductions.reserve(reductions.size() + op.getNumReductionVars());
898 for (auto symbolRef : attr->getAsRange<SymbolRefAttr>()) {
899 reductions.push_back(
900 SymbolTable::lookupNearestSymbolFrom<omp::DeclareReductionOp>(
901 op, symbolRef));
902 }
903}
904
905/// Translates the blocks contained in the given region and appends them to at
906/// the current insertion point of `builder`. The operations of the entry block
907/// are appended to the current insertion block. If set, `continuationBlockArgs`
908/// is populated with translated values that correspond to the values
909/// omp.yield'ed from the region.
910static LogicalResult inlineConvertOmpRegions(
911 Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
912 LLVM::ModuleTranslation &moduleTranslation,
913 SmallVectorImpl<llvm::Value *> *continuationBlockArgs = nullptr) {
914 if (region.empty())
915 return success();
916
917 // Special case for single-block regions that don't create additional blocks:
918 // insert operations without creating additional blocks.
919 if (llvm::hasSingleElement(C&: region)) {
920 llvm::Instruction *potentialTerminator =
921 builder.GetInsertBlock()->empty() ? nullptr
922 : &builder.GetInsertBlock()->back();
923
924 if (potentialTerminator && potentialTerminator->isTerminator())
925 potentialTerminator->removeFromParent();
926 moduleTranslation.mapBlock(mlir: &region.front(), llvm: builder.GetInsertBlock());
927
928 if (failed(Result: moduleTranslation.convertBlock(
929 bb&: region.front(), /*ignoreArguments=*/true, builder)))
930 return failure();
931
932 // The continuation arguments are simply the translated terminator operands.
933 if (continuationBlockArgs)
934 llvm::append_range(
935 C&: *continuationBlockArgs,
936 R: moduleTranslation.lookupValues(values: region.front().back().getOperands()));
937
938 // Drop the mapping that is no longer necessary so that the same region can
939 // be processed multiple times.
940 moduleTranslation.forgetMapping(region);
941
942 if (potentialTerminator && potentialTerminator->isTerminator()) {
943 llvm::BasicBlock *block = builder.GetInsertBlock();
944 if (block->empty()) {
945 // this can happen for really simple reduction init regions e.g.
946 // %0 = llvm.mlir.constant(0 : i32) : i32
947 // omp.yield(%0 : i32)
948 // because the llvm.mlir.constant (MLIR op) isn't converted into any
949 // llvm op
950 potentialTerminator->insertInto(ParentBB: block, It: block->begin());
951 } else {
952 potentialTerminator->insertAfter(InsertPos: &block->back());
953 }
954 }
955
956 return success();
957 }
958
959 SmallVector<llvm::PHINode *> phis;
960 llvm::Expected<llvm::BasicBlock *> continuationBlock =
961 convertOmpOpRegions(region, blockName, builder, moduleTranslation, continuationBlockPHIs: &phis);
962
963 if (failed(Result: handleError(result&: continuationBlock, op&: *region.getParentOp())))
964 return failure();
965
966 if (continuationBlockArgs)
967 llvm::append_range(C&: *continuationBlockArgs, R&: phis);
968 builder.SetInsertPoint(TheBB: *continuationBlock,
969 IP: (*continuationBlock)->getFirstInsertionPt());
970 return success();
971}
972
namespace {
/// Owning equivalents of OpenMPIRBuilder::(Atomic)ReductionGen that are used to
/// store lambdas with capture.
///
/// An OwningReductionGen is called with:
/// - an insertion point;
/// - the two values to combine (LHS, RHS);
/// - an out-parameter that receives the combined value.
/// It returns the insertion point after the emitted code, or an error (see
/// makeReductionGen).
using OwningReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Value *, llvm::Value *,
        llvm::Value *&)>;
/// An OwningAtomicReductionGen is called with an insertion point, a type, and
/// the LHS/RHS values. It produces no combined value, only the resulting
/// insertion point or an error (see makeAtomicReductionGen).
using OwningAtomicReductionGen =
    std::function<llvm::OpenMPIRBuilder::InsertPointOrErrorTy(
        llvm::OpenMPIRBuilder::InsertPointTy, llvm::Type *, llvm::Value *,
        llvm::Value *)>;
} // namespace
985
986/// Create an OpenMPIRBuilder-compatible reduction generator for the given
987/// reduction declaration. The generator uses `builder` but ignores its
988/// insertion point.
989static OwningReductionGen
990makeReductionGen(omp::DeclareReductionOp decl, llvm::IRBuilderBase &builder,
991 LLVM::ModuleTranslation &moduleTranslation) {
992 // The lambda is mutable because we need access to non-const methods of decl
993 // (which aren't actually mutating it), and we must capture decl by-value to
994 // avoid the dangling reference after the parent function returns.
995 OwningReductionGen gen =
996 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint,
997 llvm::Value *lhs, llvm::Value *rhs,
998 llvm::Value *&result) mutable
999 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1000 moduleTranslation.mapValue(decl.getReductionLhsArg(), lhs);
1001 moduleTranslation.mapValue(decl.getReductionRhsArg(), rhs);
1002 builder.restoreIP(IP: insertPoint);
1003 SmallVector<llvm::Value *> phis;
1004 if (failed(inlineConvertOmpRegions(decl.getReductionRegion(),
1005 "omp.reduction.nonatomic.body", builder,
1006 moduleTranslation, &phis)))
1007 return llvm::createStringError(
1008 Fmt: "failed to inline `combiner` region of `omp.declare_reduction`");
1009 result = llvm::getSingleElement(C&: phis);
1010 return builder.saveIP();
1011 };
1012 return gen;
1013}
1014
1015/// Create an OpenMPIRBuilder-compatible atomic reduction generator for the
1016/// given reduction declaration. The generator uses `builder` but ignores its
1017/// insertion point. Returns null if there is no atomic region available in the
1018/// reduction declaration.
1019static OwningAtomicReductionGen
1020makeAtomicReductionGen(omp::DeclareReductionOp decl,
1021 llvm::IRBuilderBase &builder,
1022 LLVM::ModuleTranslation &moduleTranslation) {
1023 if (decl.getAtomicReductionRegion().empty())
1024 return OwningAtomicReductionGen();
1025
1026 // The lambda is mutable because we need access to non-const methods of decl
1027 // (which aren't actually mutating it), and we must capture decl by-value to
1028 // avoid the dangling reference after the parent function returns.
1029 OwningAtomicReductionGen atomicGen =
1030 [&, decl](llvm::OpenMPIRBuilder::InsertPointTy insertPoint, llvm::Type *,
1031 llvm::Value *lhs, llvm::Value *rhs) mutable
1032 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1033 moduleTranslation.mapValue(decl.getAtomicReductionLhsArg(), lhs);
1034 moduleTranslation.mapValue(decl.getAtomicReductionRhsArg(), rhs);
1035 builder.restoreIP(IP: insertPoint);
1036 SmallVector<llvm::Value *> phis;
1037 if (failed(inlineConvertOmpRegions(decl.getAtomicReductionRegion(),
1038 "omp.reduction.atomic.body", builder,
1039 moduleTranslation, &phis)))
1040 return llvm::createStringError(
1041 Fmt: "failed to inline `atomic` region of `omp.declare_reduction`");
1042 assert(phis.empty());
1043 return builder.saveIP();
1044 };
1045 return atomicGen;
1046}
1047
1048/// Converts an OpenMP 'ordered' operation into LLVM IR using OpenMPIRBuilder.
1049static LogicalResult
1050convertOmpOrdered(Operation &opInst, llvm::IRBuilderBase &builder,
1051 LLVM::ModuleTranslation &moduleTranslation) {
1052 auto orderedOp = cast<omp::OrderedOp>(opInst);
1053
1054 if (failed(Result: checkImplementationStatus(op&: opInst)))
1055 return failure();
1056
1057 omp::ClauseDepend dependType = *orderedOp.getDoacrossDependType();
1058 bool isDependSource = dependType == omp::ClauseDepend::dependsource;
1059 unsigned numLoops = *orderedOp.getDoacrossNumLoops();
1060 SmallVector<llvm::Value *> vecValues =
1061 moduleTranslation.lookupValues(values: orderedOp.getDoacrossDependVars());
1062
1063 size_t indexVecValues = 0;
1064 while (indexVecValues < vecValues.size()) {
1065 SmallVector<llvm::Value *> storeValues;
1066 storeValues.reserve(N: numLoops);
1067 for (unsigned i = 0; i < numLoops; i++) {
1068 storeValues.push_back(Elt: vecValues[indexVecValues]);
1069 indexVecValues++;
1070 }
1071 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1072 findAllocaInsertPoint(builder, moduleTranslation);
1073 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1074 builder.restoreIP(IP: moduleTranslation.getOpenMPBuilder()->createOrderedDepend(
1075 Loc: ompLoc, AllocaIP: allocaIP, NumLoops: numLoops, StoreValues: storeValues, Name: ".cnt.addr", IsDependSource: isDependSource));
1076 }
1077 return success();
1078}
1079
1080/// Converts an OpenMP 'ordered_region' operation into LLVM IR using
1081/// OpenMPIRBuilder.
1082static LogicalResult
1083convertOmpOrderedRegion(Operation &opInst, llvm::IRBuilderBase &builder,
1084 LLVM::ModuleTranslation &moduleTranslation) {
1085 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1086 auto orderedRegionOp = cast<omp::OrderedRegionOp>(opInst);
1087
1088 if (failed(Result: checkImplementationStatus(op&: opInst)))
1089 return failure();
1090
1091 auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1092 // OrderedOp has only one region associated with it.
1093 auto &region = cast<omp::OrderedRegionOp>(opInst).getRegion();
1094 builder.restoreIP(IP: codeGenIP);
1095 return convertOmpOpRegions(region, "omp.ordered.region", builder,
1096 moduleTranslation)
1097 .takeError();
1098 };
1099
1100 // TODO: Perform finalization actions for variables. This has to be
1101 // called for variables which have destructors/finalizers.
1102 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1103
1104 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1105 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1106 moduleTranslation.getOpenMPBuilder()->createOrderedThreadsSimd(
1107 Loc: ompLoc, BodyGenCB: bodyGenCB, FiniCB: finiCB, IsThreads: !orderedRegionOp.getParLevelSimd());
1108
1109 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
1110 return failure();
1111
1112 builder.restoreIP(IP: *afterIP);
1113 return success();
1114}
1115
1116namespace {
1117/// Contains the arguments for an LLVM store operation
1118struct DeferredStore {
1119 DeferredStore(llvm::Value *value, llvm::Value *address)
1120 : value(value), address(address) {}
1121
1122 llvm::Value *value;
1123 llvm::Value *address;
1124};
1125} // namespace
1126
1127/// Allocate space for privatized reduction variables.
1128/// `deferredStores` contains information to create store operations which needs
1129/// to be inserted after all allocas
1130template <typename T>
1131static LogicalResult
1132allocReductionVars(T loop, ArrayRef<BlockArgument> reductionArgs,
1133 llvm::IRBuilderBase &builder,
1134 LLVM::ModuleTranslation &moduleTranslation,
1135 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1136 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1137 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1138 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1139 SmallVectorImpl<DeferredStore> &deferredStores,
1140 llvm::ArrayRef<bool> isByRefs) {
1141 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1142 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1143
1144 // delay creating stores until after all allocas
1145 deferredStores.reserve(N: loop.getNumReductionVars());
1146
1147 for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
1148 Region &allocRegion = reductionDecls[i].getAllocRegion();
1149 if (isByRefs[i]) {
1150 if (allocRegion.empty())
1151 continue;
1152
1153 SmallVector<llvm::Value *, 1> phis;
1154 if (failed(Result: inlineConvertOmpRegions(region&: allocRegion, blockName: "omp.reduction.alloc",
1155 builder, moduleTranslation, continuationBlockArgs: &phis)))
1156 return loop.emitError(
1157 "failed to inline `alloc` region of `omp.declare_reduction`");
1158
1159 assert(phis.size() == 1 && "expected one allocation to be yielded");
1160 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1161
1162 // Allocate reduction variable (which is a pointer to the real reduction
1163 // variable allocated in the inlined region)
1164 llvm::Value *var = builder.CreateAlloca(
1165 moduleTranslation.convertType(type: reductionDecls[i].getType()));
1166
1167 llvm::Type *ptrTy = builder.getPtrTy();
1168 llvm::Value *castVar =
1169 builder.CreatePointerBitCastOrAddrSpaceCast(V: var, DestTy: ptrTy);
1170 llvm::Value *castPhi =
1171 builder.CreatePointerBitCastOrAddrSpaceCast(V: phis[0], DestTy: ptrTy);
1172
1173 deferredStores.emplace_back(Args&: castPhi, Args&: castVar);
1174
1175 privateReductionVariables[i] = castVar;
1176 moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: castPhi);
1177 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castPhi);
1178 } else {
1179 assert(allocRegion.empty() &&
1180 "allocaction is implicit for by-val reduction");
1181 llvm::Value *var = builder.CreateAlloca(
1182 moduleTranslation.convertType(type: reductionDecls[i].getType()));
1183
1184 llvm::Type *ptrTy = builder.getPtrTy();
1185 llvm::Value *castVar =
1186 builder.CreatePointerBitCastOrAddrSpaceCast(V: var, DestTy: ptrTy);
1187
1188 moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: castVar);
1189 privateReductionVariables[i] = castVar;
1190 reductionVariableMap.try_emplace(loop.getReductionVars()[i], castVar);
1191 }
1192 }
1193
1194 return success();
1195}
1196
1197/// Map input arguments to reduction initialization region
1198template <typename T>
1199static void
1200mapInitializationArgs(T loop, LLVM::ModuleTranslation &moduleTranslation,
1201 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1202 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1203 unsigned i) {
1204 // map input argument to the initialization region
1205 mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
1206 Region &initializerRegion = reduction.getInitializerRegion();
1207 Block &entry = initializerRegion.front();
1208
1209 mlir::Value mlirSource = loop.getReductionVars()[i];
1210 llvm::Value *llvmSource = moduleTranslation.lookupValue(value: mlirSource);
1211 assert(llvmSource && "lookup reduction var");
1212 moduleTranslation.mapValue(reduction.getInitializerMoldArg(), llvmSource);
1213
1214 if (entry.getNumArguments() > 1) {
1215 llvm::Value *allocation =
1216 reductionVariableMap.lookup(Val: loop.getReductionVars()[i]);
1217 moduleTranslation.mapValue(reduction.getInitializerAllocArg(), allocation);
1218 }
1219}
1220
1221static void
1222setInsertPointForPossiblyEmptyBlock(llvm::IRBuilderBase &builder,
1223 llvm::BasicBlock *block = nullptr) {
1224 if (block == nullptr)
1225 block = builder.GetInsertBlock();
1226
1227 if (block->empty() || block->getTerminator() == nullptr)
1228 builder.SetInsertPoint(block);
1229 else
1230 builder.SetInsertPoint(block->getTerminator());
1231}
1232
1233/// Inline reductions' `init` regions. This functions assumes that the
1234/// `builder`'s insertion point is where the user wants the `init` regions to be
1235/// inlined; i.e. it does not try to find a proper insertion location for the
1236/// `init` regions. It also leaves the `builder's insertions point in a state
1237/// where the user can continue the code-gen directly afterwards.
1238template <typename OP>
1239static LogicalResult
1240initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1241 llvm::IRBuilderBase &builder,
1242 LLVM::ModuleTranslation &moduleTranslation,
1243 llvm::BasicBlock *latestAllocaBlock,
1244 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1245 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1246 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1247 llvm::ArrayRef<bool> isByRef,
1248 SmallVectorImpl<DeferredStore> &deferredStores) {
1249 if (op.getNumReductionVars() == 0)
1250 return success();
1251
1252 llvm::BasicBlock *initBlock = splitBB(Builder&: builder, CreateBranch: true, Name: "omp.reduction.init");
1253 auto allocaIP = llvm::IRBuilderBase::InsertPoint(
1254 latestAllocaBlock, latestAllocaBlock->getTerminator()->getIterator());
1255 builder.restoreIP(IP: allocaIP);
1256 SmallVector<llvm::Value *> byRefVars(op.getNumReductionVars());
1257
1258 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1259 if (isByRef[i]) {
1260 if (!reductionDecls[i].getAllocRegion().empty())
1261 continue;
1262
1263 // TODO: remove after all users of by-ref are updated to use the alloc
1264 // region: Allocate reduction variable (which is a pointer to the real
1265 // reduciton variable allocated in the inlined region)
1266 byRefVars[i] = builder.CreateAlloca(
1267 moduleTranslation.convertType(type: reductionDecls[i].getType()));
1268 }
1269 }
1270
1271 setInsertPointForPossiblyEmptyBlock(builder, block: initBlock);
1272
1273 // store result of the alloc region to the allocated pointer to the real
1274 // reduction variable
1275 for (auto [data, addr] : deferredStores)
1276 builder.CreateStore(Val: data, Ptr: addr);
1277
1278 // Before the loop, store the initial values of reductions into reduction
1279 // variables. Although this could be done after allocas, we don't want to mess
1280 // up with the alloca insertion point.
1281 for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1282 SmallVector<llvm::Value *, 1> phis;
1283
1284 // map block argument to initializer region
1285 mapInitializationArgs(op, moduleTranslation, reductionDecls,
1286 reductionVariableMap, i);
1287
1288 if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
1289 "omp.reduction.neutral", builder,
1290 moduleTranslation, &phis)))
1291 return failure();
1292
1293 assert(phis.size() == 1 && "expected one value to be yielded from the "
1294 "reduction neutral element declaration region");
1295
1296 setInsertPointForPossiblyEmptyBlock(builder);
1297
1298 if (isByRef[i]) {
1299 if (!reductionDecls[i].getAllocRegion().empty())
1300 // done in allocReductionVars
1301 continue;
1302
1303 // TODO: this path can be removed once all users of by-ref are updated to
1304 // use an alloc region
1305
1306 // Store the result of the inlined region to the allocated reduction var
1307 // ptr
1308 builder.CreateStore(Val: phis[0], Ptr: byRefVars[i]);
1309
1310 privateReductionVariables[i] = byRefVars[i];
1311 moduleTranslation.mapValue(mlir: reductionArgs[i], llvm: phis[0]);
1312 reductionVariableMap.try_emplace(op.getReductionVars()[i], phis[0]);
1313 } else {
1314 // for by-ref case the store is inside of the reduction region
1315 builder.CreateStore(Val: phis[0], Ptr: privateReductionVariables[i]);
1316 // the rest was handled in allocByValReductionVars
1317 }
1318
1319 // forget the mapping for the initializer region because we might need a
1320 // different mapping if this reduction declaration is re-used for a
1321 // different variable
1322 moduleTranslation.forgetMapping(region&: reductionDecls[i].getInitializerRegion());
1323 }
1324
1325 return success();
1326}
1327
1328/// Collect reduction info
1329template <typename T>
1330static void collectReductionInfo(
1331 T loop, llvm::IRBuilderBase &builder,
1332 LLVM::ModuleTranslation &moduleTranslation,
1333 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1334 SmallVectorImpl<OwningReductionGen> &owningReductionGens,
1335 SmallVectorImpl<OwningAtomicReductionGen> &owningAtomicReductionGens,
1336 const ArrayRef<llvm::Value *> privateReductionVariables,
1337 SmallVectorImpl<llvm::OpenMPIRBuilder::ReductionInfo> &reductionInfos) {
1338 unsigned numReductions = loop.getNumReductionVars();
1339
1340 for (unsigned i = 0; i < numReductions; ++i) {
1341 owningReductionGens.push_back(
1342 makeReductionGen(reductionDecls[i], builder, moduleTranslation));
1343 owningAtomicReductionGens.push_back(
1344 makeAtomicReductionGen(reductionDecls[i], builder, moduleTranslation));
1345 }
1346
1347 // Collect the reduction information.
1348 reductionInfos.reserve(N: numReductions);
1349 for (unsigned i = 0; i < numReductions; ++i) {
1350 llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
1351 if (owningAtomicReductionGens[i])
1352 atomicGen = owningAtomicReductionGens[i];
1353 llvm::Value *variable =
1354 moduleTranslation.lookupValue(value: loop.getReductionVars()[i]);
1355 reductionInfos.push_back(
1356 {moduleTranslation.convertType(type: reductionDecls[i].getType()), variable,
1357 privateReductionVariables[i],
1358 /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
1359 owningReductionGens[i],
1360 /*ReductionGenClang=*/nullptr, atomicGen});
1361 }
1362}
1363
1364/// handling of DeclareReductionOp's cleanup region
1365static LogicalResult
1366inlineOmpRegionCleanup(llvm::SmallVectorImpl<Region *> &cleanupRegions,
1367 llvm::ArrayRef<llvm::Value *> privateVariables,
1368 LLVM::ModuleTranslation &moduleTranslation,
1369 llvm::IRBuilderBase &builder, StringRef regionName,
1370 bool shouldLoadCleanupRegionArg = true) {
1371 for (auto [i, cleanupRegion] : llvm::enumerate(First&: cleanupRegions)) {
1372 if (cleanupRegion->empty())
1373 continue;
1374
1375 // map the argument to the cleanup region
1376 Block &entry = cleanupRegion->front();
1377
1378 llvm::Instruction *potentialTerminator =
1379 builder.GetInsertBlock()->empty() ? nullptr
1380 : &builder.GetInsertBlock()->back();
1381 if (potentialTerminator && potentialTerminator->isTerminator())
1382 builder.SetInsertPoint(potentialTerminator);
1383 llvm::Value *privateVarValue =
1384 shouldLoadCleanupRegionArg
1385 ? builder.CreateLoad(
1386 Ty: moduleTranslation.convertType(type: entry.getArgument(i: 0).getType()),
1387 Ptr: privateVariables[i])
1388 : privateVariables[i];
1389
1390 moduleTranslation.mapValue(mlir: entry.getArgument(i: 0), llvm: privateVarValue);
1391
1392 if (failed(Result: inlineConvertOmpRegions(region&: *cleanupRegion, blockName: regionName, builder,
1393 moduleTranslation)))
1394 return failure();
1395
1396 // clear block argument mapping in case it needs to be re-created with a
1397 // different source for another use of the same reduction decl
1398 moduleTranslation.forgetMapping(region&: *cleanupRegion);
1399 }
1400 return success();
1401}
1402
1403// TODO: not used by ParallelOp
1404template <class OP>
1405static LogicalResult createReductionsAndCleanup(
1406 OP op, llvm::IRBuilderBase &builder,
1407 LLVM::ModuleTranslation &moduleTranslation,
1408 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1409 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1410 ArrayRef<llvm::Value *> privateReductionVariables, ArrayRef<bool> isByRef,
1411 bool isNowait = false, bool isTeamsReduction = false) {
1412 // Process the reductions if required.
1413 if (op.getNumReductionVars() == 0)
1414 return success();
1415
1416 SmallVector<OwningReductionGen> owningReductionGens;
1417 SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
1418 SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1419
1420 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1421
1422 // Create the reduction generators. We need to own them here because
1423 // ReductionInfo only accepts references to the generators.
1424 collectReductionInfo(op, builder, moduleTranslation, reductionDecls,
1425 owningReductionGens, owningAtomicReductionGens,
1426 privateReductionVariables, reductionInfos);
1427
1428 // The call to createReductions below expects the block to have a
1429 // terminator. Create an unreachable instruction to serve as terminator
1430 // and remove it later.
1431 llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
1432 builder.SetInsertPoint(tempTerminator);
1433 llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
1434 ompBuilder->createReductions(Loc: builder.saveIP(), AllocaIP: allocaIP, ReductionInfos: reductionInfos,
1435 IsByRef: isByRef, IsNoWait: isNowait, IsTeamsReduction: isTeamsReduction);
1436
1437 if (failed(handleError(contInsertPoint, *op)))
1438 return failure();
1439
1440 if (!contInsertPoint->getBlock())
1441 return op->emitOpError() << "failed to convert reductions";
1442
1443 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1444 ompBuilder->createBarrier(*contInsertPoint, llvm::omp::OMPD_for);
1445
1446 if (failed(handleError(afterIP, *op)))
1447 return failure();
1448
1449 tempTerminator->eraseFromParent();
1450 builder.restoreIP(IP: *afterIP);
1451
1452 // after the construct, deallocate private reduction variables
1453 SmallVector<Region *> reductionRegions;
1454 llvm::transform(reductionDecls, std::back_inserter(x&: reductionRegions),
1455 [](omp::DeclareReductionOp reductionDecl) {
1456 return &reductionDecl.getCleanupRegion();
1457 });
1458 return inlineOmpRegionCleanup(cleanupRegions&: reductionRegions, privateVariables: privateReductionVariables,
1459 moduleTranslation, builder,
1460 regionName: "omp.reduction.cleanup");
1461 return success();
1462}
1463
1464static ArrayRef<bool> getIsByRef(std::optional<ArrayRef<bool>> attr) {
1465 if (!attr)
1466 return {};
1467 return *attr;
1468}
1469
1470// TODO: not used by omp.parallel
1471template <typename OP>
1472static LogicalResult allocAndInitializeReductionVars(
1473 OP op, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
1474 LLVM::ModuleTranslation &moduleTranslation,
1475 llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1476 SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
1477 SmallVectorImpl<llvm::Value *> &privateReductionVariables,
1478 DenseMap<Value, llvm::Value *> &reductionVariableMap,
1479 llvm::ArrayRef<bool> isByRef) {
1480 if (op.getNumReductionVars() == 0)
1481 return success();
1482
1483 SmallVector<DeferredStore> deferredStores;
1484
1485 if (failed(allocReductionVars(op, reductionArgs, builder, moduleTranslation,
1486 allocaIP, reductionDecls,
1487 privateReductionVariables, reductionVariableMap,
1488 deferredStores, isByRef)))
1489 return failure();
1490
1491 return initReductionVars(op, reductionArgs, builder, moduleTranslation,
1492 allocaIP.getBlock(), reductionDecls,
1493 privateReductionVariables, reductionVariableMap,
1494 isByRef, deferredStores);
1495}
1496
1497/// Return the llvm::Value * corresponding to the `privateVar` that
1498/// is being privatized. It isn't always as simple as looking up
1499/// moduleTranslation with privateVar. For instance, in case of
1500/// an allocatable, the descriptor for the allocatable is privatized.
1501/// This descriptor is mapped using an MapInfoOp. So, this function
1502/// will return a pointer to the llvm::Value corresponding to the
1503/// block argument for the mapped descriptor.
1504static llvm::Value *
1505findAssociatedValue(Value privateVar, llvm::IRBuilderBase &builder,
1506 LLVM::ModuleTranslation &moduleTranslation,
1507 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1508 if (mappedPrivateVars == nullptr || !mappedPrivateVars->contains(Val: privateVar))
1509 return moduleTranslation.lookupValue(value: privateVar);
1510
1511 Value blockArg = (*mappedPrivateVars)[privateVar];
1512 Type privVarType = privateVar.getType();
1513 Type blockArgType = blockArg.getType();
1514 assert(isa<LLVM::LLVMPointerType>(blockArgType) &&
1515 "A block argument corresponding to a mapped var should have "
1516 "!llvm.ptr type");
1517
1518 if (privVarType == blockArgType)
1519 return moduleTranslation.lookupValue(value: blockArg);
1520
1521 // This typically happens when the privatized type is lowered from
1522 // boxchar<KIND> and gets lowered to !llvm.struct<(ptr, i64)>. That is the
1523 // struct/pair is passed by value. But, mapped values are passed only as
1524 // pointers, so before we privatize, we must load the pointer.
1525 if (!isa<LLVM::LLVMPointerType>(privVarType))
1526 return builder.CreateLoad(Ty: moduleTranslation.convertType(type: privVarType),
1527 Ptr: moduleTranslation.lookupValue(value: blockArg));
1528
1529 return moduleTranslation.lookupValue(value: privateVar);
1530}
1531
1532/// Initialize a single (first)private variable. You probably want to use
1533/// allocateAndInitPrivateVars instead of this.
1534/// This returns the private variable which has been initialized. This
1535/// variable should be mapped before constructing the body of the Op.
1536static llvm::Expected<llvm::Value *> initPrivateVar(
1537 llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation,
1538 omp::PrivateClauseOp &privDecl, Value mlirPrivVar, BlockArgument &blockArg,
1539 llvm::Value *llvmPrivateVar, llvm::BasicBlock *privInitBlock,
1540 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1541 Region &initRegion = privDecl.getInitRegion();
1542 if (initRegion.empty())
1543 return llvmPrivateVar;
1544
1545 // map initialization region block arguments
1546 llvm::Value *nonPrivateVar = findAssociatedValue(
1547 privateVar: mlirPrivVar, builder, moduleTranslation, mappedPrivateVars);
1548 assert(nonPrivateVar);
1549 moduleTranslation.mapValue(privDecl.getInitMoldArg(), nonPrivateVar);
1550 moduleTranslation.mapValue(privDecl.getInitPrivateArg(), llvmPrivateVar);
1551
1552 // in-place convert the private initialization region
1553 SmallVector<llvm::Value *, 1> phis;
1554 if (failed(Result: inlineConvertOmpRegions(region&: initRegion, blockName: "omp.private.init", builder,
1555 moduleTranslation, continuationBlockArgs: &phis)))
1556 return llvm::createStringError(
1557 Fmt: "failed to inline `init` region of `omp.private`");
1558
1559 assert(phis.size() == 1 && "expected one allocation to be yielded");
1560
1561 // clear init region block argument mapping in case it needs to be
1562 // re-created with a different source for another use of the same
1563 // reduction decl
1564 moduleTranslation.forgetMapping(region&: initRegion);
1565
1566 // Prefer the value yielded from the init region to the allocated private
1567 // variable in case the region is operating on arguments by-value (e.g.
1568 // Fortran character boxes).
1569 return phis[0];
1570}
1571
1572static llvm::Error
1573initPrivateVars(llvm::IRBuilderBase &builder,
1574 LLVM::ModuleTranslation &moduleTranslation,
1575 PrivateVarsInfo &privateVarsInfo,
1576 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1577 if (privateVarsInfo.blockArgs.empty())
1578 return llvm::Error::success();
1579
1580 llvm::BasicBlock *privInitBlock = splitBB(Builder&: builder, CreateBranch: true, Name: "omp.private.init");
1581 setInsertPointForPossiblyEmptyBlock(builder, block: privInitBlock);
1582
1583 for (auto [idx, zip] : llvm::enumerate(llvm::zip_equal(
1584 privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1585 privateVarsInfo.blockArgs, privateVarsInfo.llvmVars))) {
1586 auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVar] = zip;
1587 llvm::Expected<llvm::Value *> privVarOrErr = initPrivateVar(
1588 builder, moduleTranslation, privDecl, mlirPrivVar, blockArg,
1589 llvmPrivateVar, privInitBlock, mappedPrivateVars);
1590
1591 if (!privVarOrErr)
1592 return privVarOrErr.takeError();
1593
1594 llvmPrivateVar = privVarOrErr.get();
1595 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
1596
1597 setInsertPointForPossiblyEmptyBlock(builder);
1598 }
1599
1600 return llvm::Error::success();
1601}
1602
1603/// Allocate and initialize delayed private variables. Returns the basic block
1604/// which comes after all of these allocations. llvm::Value * for each of these
1605/// private variables are populated in llvmPrivateVars.
1606static llvm::Expected<llvm::BasicBlock *>
1607allocatePrivateVars(llvm::IRBuilderBase &builder,
1608 LLVM::ModuleTranslation &moduleTranslation,
1609 PrivateVarsInfo &privateVarsInfo,
1610 const llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
1611 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1612 // Allocate private vars
1613 llvm::Instruction *allocaTerminator = allocaIP.getBlock()->getTerminator();
1614 splitBB(IP: llvm::OpenMPIRBuilder::InsertPointTy(allocaIP.getBlock(),
1615 allocaTerminator->getIterator()),
1616 CreateBranch: true, DL: allocaTerminator->getStableDebugLoc(),
1617 Name: "omp.region.after_alloca");
1618
1619 llvm::IRBuilderBase::InsertPointGuard guard(builder);
1620 // Update the allocaTerminator since the alloca block was split above.
1621 allocaTerminator = allocaIP.getBlock()->getTerminator();
1622 builder.SetInsertPoint(allocaTerminator);
1623 // The new terminator is an uncondition branch created by the splitBB above.
1624 assert(allocaTerminator->getNumSuccessors() == 1 &&
1625 "This is an unconditional branch created by splitBB");
1626
1627 llvm::DataLayout dataLayout = builder.GetInsertBlock()->getDataLayout();
1628 llvm::BasicBlock *afterAllocas = allocaTerminator->getSuccessor(Idx: 0);
1629
1630 unsigned int allocaAS =
1631 moduleTranslation.getLLVMModule()->getDataLayout().getAllocaAddrSpace();
1632 unsigned int defaultAS = moduleTranslation.getLLVMModule()
1633 ->getDataLayout()
1634 .getProgramAddressSpace();
1635
1636 for (auto [privDecl, mlirPrivVar, blockArg] :
1637 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
1638 privateVarsInfo.blockArgs)) {
1639 llvm::Type *llvmAllocType =
1640 moduleTranslation.convertType(privDecl.getType());
1641 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
1642 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
1643 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
1644 if (allocaAS != defaultAS)
1645 llvmPrivateVar = builder.CreateAddrSpaceCast(llvmPrivateVar,
1646 builder.getPtrTy(defaultAS));
1647
1648 privateVarsInfo.llvmVars.push_back(llvmPrivateVar);
1649 }
1650
1651 return afterAllocas;
1652}
1653
1654static LogicalResult copyFirstPrivateVars(
1655 mlir::Operation *op, llvm::IRBuilderBase &builder,
1656 LLVM::ModuleTranslation &moduleTranslation,
1657 SmallVectorImpl<mlir::Value> &mlirPrivateVars,
1658 ArrayRef<llvm::Value *> llvmPrivateVars,
1659 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls, bool insertBarrier,
1660 llvm::DenseMap<Value, Value> *mappedPrivateVars = nullptr) {
1661 // Apply copy region for firstprivate.
1662 bool needsFirstprivate =
1663 llvm::any_of(privateDecls, [](omp::PrivateClauseOp &privOp) {
1664 return privOp.getDataSharingType() ==
1665 omp::DataSharingClauseType::FirstPrivate;
1666 });
1667
1668 if (!needsFirstprivate)
1669 return success();
1670
1671 llvm::BasicBlock *copyBlock =
1672 splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.copy");
1673 setInsertPointForPossiblyEmptyBlock(builder, block: copyBlock);
1674
1675 for (auto [decl, mlirVar, llvmVar] :
1676 llvm::zip_equal(privateDecls, mlirPrivateVars, llvmPrivateVars)) {
1677 if (decl.getDataSharingType() != omp::DataSharingClauseType::FirstPrivate)
1678 continue;
1679
1680 // copyRegion implements `lhs = rhs`
1681 Region &copyRegion = decl.getCopyRegion();
1682
1683 // map copyRegion rhs arg
1684 llvm::Value *nonPrivateVar = findAssociatedValue(
1685 mlirVar, builder, moduleTranslation, mappedPrivateVars);
1686 assert(nonPrivateVar);
1687 moduleTranslation.mapValue(decl.getCopyMoldArg(), nonPrivateVar);
1688
1689 // map copyRegion lhs arg
1690 moduleTranslation.mapValue(decl.getCopyPrivateArg(), llvmVar);
1691
1692 // in-place convert copy region
1693 if (failed(inlineConvertOmpRegions(copyRegion, "omp.private.copy", builder,
1694 moduleTranslation)))
1695 return decl.emitError("failed to inline `copy` region of `omp.private`");
1696
1697 setInsertPointForPossiblyEmptyBlock(builder);
1698
1699 // ignore unused value yielded from copy region
1700
1701 // clear copy region block argument mapping in case it needs to be
1702 // re-created with different sources for reuse of the same reduction
1703 // decl
1704 moduleTranslation.forgetMapping(copyRegion);
1705 }
1706
1707 if (insertBarrier) {
1708 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1709 llvm::OpenMPIRBuilder::InsertPointOrErrorTy res =
1710 ompBuilder->createBarrier(builder.saveIP(), llvm::omp::OMPD_barrier);
1711 if (failed(Result: handleError(result&: res, op&: *op)))
1712 return failure();
1713 }
1714
1715 return success();
1716}
1717
1718static LogicalResult
1719cleanupPrivateVars(llvm::IRBuilderBase &builder,
1720 LLVM::ModuleTranslation &moduleTranslation, Location loc,
1721 SmallVectorImpl<llvm::Value *> &llvmPrivateVars,
1722 SmallVectorImpl<omp::PrivateClauseOp> &privateDecls) {
1723 // private variable deallocation
1724 SmallVector<Region *> privateCleanupRegions;
1725 llvm::transform(privateDecls, std::back_inserter(x&: privateCleanupRegions),
1726 [](omp::PrivateClauseOp privatizer) {
1727 return &privatizer.getDeallocRegion();
1728 });
1729
1730 if (failed(Result: inlineOmpRegionCleanup(
1731 cleanupRegions&: privateCleanupRegions, privateVariables: llvmPrivateVars, moduleTranslation, builder,
1732 regionName: "omp.private.dealloc", /*shouldLoadCleanupRegionArg=*/false)))
1733 return mlir::emitError(loc, message: "failed to inline `dealloc` region of an "
1734 "`omp.private` op in");
1735
1736 return success();
1737}
1738
1739/// Returns true if the construct contains omp.cancel or omp.cancellation_point
1740static bool constructIsCancellable(Operation *op) {
1741 // omp.cancel and omp.cancellation_point must be "closely nested" so they will
1742 // be visible and not inside of function calls. This is enforced by the
1743 // verifier.
1744 return op
1745 ->walk(callback: [](Operation *child) {
1746 if (mlir::isa<omp::CancelOp, omp::CancellationPointOp>(child))
1747 return WalkResult::interrupt();
1748 return WalkResult::advance();
1749 })
1750 .wasInterrupted();
1751}
1752
1753static LogicalResult
1754convertOmpSections(Operation &opInst, llvm::IRBuilderBase &builder,
1755 LLVM::ModuleTranslation &moduleTranslation) {
1756 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1757 using StorableBodyGenCallbackTy =
1758 llvm::OpenMPIRBuilder::StorableBodyGenCallbackTy;
1759
1760 auto sectionsOp = cast<omp::SectionsOp>(opInst);
1761
1762 if (failed(Result: checkImplementationStatus(op&: opInst)))
1763 return failure();
1764
1765 llvm::ArrayRef<bool> isByRef = getIsByRef(sectionsOp.getReductionByref());
1766 assert(isByRef.size() == sectionsOp.getNumReductionVars());
1767
1768 SmallVector<omp::DeclareReductionOp> reductionDecls;
1769 collectReductionDecls(sectionsOp, reductionDecls);
1770 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1771 findAllocaInsertPoint(builder, moduleTranslation);
1772
1773 SmallVector<llvm::Value *> privateReductionVariables(
1774 sectionsOp.getNumReductionVars());
1775 DenseMap<Value, llvm::Value *> reductionVariableMap;
1776
1777 MutableArrayRef<BlockArgument> reductionArgs =
1778 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
1779
1780 if (failed(allocAndInitializeReductionVars(
1781 sectionsOp, reductionArgs, builder, moduleTranslation, allocaIP,
1782 reductionDecls, privateReductionVariables, reductionVariableMap,
1783 isByRef)))
1784 return failure();
1785
1786 SmallVector<StorableBodyGenCallbackTy> sectionCBs;
1787
1788 for (Operation &op : *sectionsOp.getRegion().begin()) {
1789 auto sectionOp = dyn_cast<omp::SectionOp>(op);
1790 if (!sectionOp) // omp.terminator
1791 continue;
1792
1793 Region &region = sectionOp.getRegion();
1794 auto sectionCB = [&sectionsOp, &region, &builder, &moduleTranslation](
1795 InsertPointTy allocaIP, InsertPointTy codeGenIP) {
1796 builder.restoreIP(codeGenIP);
1797
1798 // map the omp.section reduction block argument to the omp.sections block
1799 // arguments
1800 // TODO: this assumes that the only block arguments are reduction
1801 // variables
1802 assert(region.getNumArguments() ==
1803 sectionsOp.getRegion().getNumArguments());
1804 for (auto [sectionsArg, sectionArg] : llvm::zip_equal(
1805 sectionsOp.getRegion().getArguments(), region.getArguments())) {
1806 llvm::Value *llvmVal = moduleTranslation.lookupValue(sectionsArg);
1807 assert(llvmVal);
1808 moduleTranslation.mapValue(sectionArg, llvmVal);
1809 }
1810
1811 return convertOmpOpRegions(region, "omp.section.region", builder,
1812 moduleTranslation)
1813 .takeError();
1814 };
1815 sectionCBs.push_back(sectionCB);
1816 }
1817
1818 // No sections within omp.sections operation - skip generation. This situation
1819 // is only possible if there is only a terminator operation inside the
1820 // sections operation
1821 if (sectionCBs.empty())
1822 return success();
1823
1824 assert(isa<omp::SectionOp>(*sectionsOp.getRegion().op_begin()));
1825
1826 // TODO: Perform appropriate actions according to the data-sharing
1827 // attribute (shared, private, firstprivate, ...) of variables.
1828 // Currently defaults to shared.
1829 auto privCB = [&](InsertPointTy, InsertPointTy codeGenIP, llvm::Value &,
1830 llvm::Value &vPtr, llvm::Value *&replacementValue)
1831 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
1832 replacementValue = &vPtr;
1833 return codeGenIP;
1834 };
1835
1836 // TODO: Perform finalization actions for variables. This has to be
1837 // called for variables which have destructors/finalizers.
1838 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1839
1840 allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
1841 bool isCancellable = constructIsCancellable(sectionsOp);
1842 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1843 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1844 moduleTranslation.getOpenMPBuilder()->createSections(
1845 Loc: ompLoc, AllocaIP: allocaIP, SectionCBs: sectionCBs, PrivCB: privCB, FiniCB: finiCB, IsCancellable: isCancellable,
1846 IsNowait: sectionsOp.getNowait());
1847
1848 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
1849 return failure();
1850
1851 builder.restoreIP(IP: *afterIP);
1852
1853 // Process the reductions if required.
1854 return createReductionsAndCleanup(
1855 sectionsOp, builder, moduleTranslation, allocaIP, reductionDecls,
1856 privateReductionVariables, isByRef, sectionsOp.getNowait());
1857}
1858
1859/// Converts an OpenMP single construct into LLVM IR using OpenMPIRBuilder.
1860static LogicalResult
1861convertOmpSingle(omp::SingleOp &singleOp, llvm::IRBuilderBase &builder,
1862 LLVM::ModuleTranslation &moduleTranslation) {
1863 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1864 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
1865
1866 if (failed(checkImplementationStatus(*singleOp)))
1867 return failure();
1868
1869 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1870 builder.restoreIP(IP: codegenIP);
1871 return convertOmpOpRegions(singleOp.getRegion(), "omp.single.region",
1872 builder, moduleTranslation)
1873 .takeError();
1874 };
1875 auto finiCB = [&](InsertPointTy codeGenIP) { return llvm::Error::success(); };
1876
1877 // Handle copyprivate
1878 Operation::operand_range cpVars = singleOp.getCopyprivateVars();
1879 std::optional<ArrayAttr> cpFuncs = singleOp.getCopyprivateSyms();
1880 llvm::SmallVector<llvm::Value *> llvmCPVars;
1881 llvm::SmallVector<llvm::Function *> llvmCPFuncs;
1882 for (size_t i = 0, e = cpVars.size(); i < e; ++i) {
1883 llvmCPVars.push_back(Elt: moduleTranslation.lookupValue(value: cpVars[i]));
1884 auto llvmFuncOp = SymbolTable::lookupNearestSymbolFrom<LLVM::LLVMFuncOp>(
1885 singleOp, cast<SymbolRefAttr>((*cpFuncs)[i]));
1886 llvmCPFuncs.push_back(
1887 Elt: moduleTranslation.lookupFunction(name: llvmFuncOp.getName()));
1888 }
1889
1890 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
1891 moduleTranslation.getOpenMPBuilder()->createSingle(
1892 Loc: ompLoc, BodyGenCB: bodyCB, FiniCB: finiCB, IsNowait: singleOp.getNowait(), CPVars: llvmCPVars,
1893 CPFuncs: llvmCPFuncs);
1894
1895 if (failed(handleError(afterIP, *singleOp)))
1896 return failure();
1897
1898 builder.restoreIP(IP: *afterIP);
1899 return success();
1900}
1901
1902static bool teamsReductionContainedInDistribute(omp::TeamsOp teamsOp) {
1903 auto iface =
1904 llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(teamsOp.getOperation());
1905 // Check that all uses of the reduction block arg has the same distribute op
1906 // parent.
1907 llvm::SmallVector<mlir::Operation *> debugUses;
1908 Operation *distOp = nullptr;
1909 for (auto ra : iface.getReductionBlockArgs())
1910 for (auto &use : ra.getUses()) {
1911 auto *useOp = use.getOwner();
1912 // Ignore debug uses.
1913 if (mlir::isa<LLVM::DbgDeclareOp, LLVM::DbgValueOp>(useOp)) {
1914 debugUses.push_back(useOp);
1915 continue;
1916 }
1917
1918 auto currentDistOp = useOp->getParentOfType<omp::DistributeOp>();
1919 // Use is not inside a distribute op - return false
1920 if (!currentDistOp)
1921 return false;
1922 // Multiple distribute operations - return false
1923 Operation *currentOp = currentDistOp.getOperation();
1924 if (distOp && (distOp != currentOp))
1925 return false;
1926
1927 distOp = currentOp;
1928 }
1929
1930 // If we are going to use distribute reduction then remove any debug uses of
1931 // the reduction parameters in teamsOp. Otherwise they will be left without
1932 // any mapped value in moduleTranslation and will eventually error out.
1933 for (auto use : debugUses)
1934 use->erase();
1935 return true;
1936}
1937
1938// Convert an OpenMP Teams construct to LLVM IR using OpenMPIRBuilder
1939static LogicalResult
1940convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
1941 LLVM::ModuleTranslation &moduleTranslation) {
1942 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
1943 if (failed(checkImplementationStatus(*op)))
1944 return failure();
1945
1946 DenseMap<Value, llvm::Value *> reductionVariableMap;
1947 unsigned numReductionVars = op.getNumReductionVars();
1948 SmallVector<omp::DeclareReductionOp> reductionDecls;
1949 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
1950 llvm::ArrayRef<bool> isByRef;
1951 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
1952 findAllocaInsertPoint(builder, moduleTranslation);
1953
1954 // Only do teams reduction if there is no distribute op that captures the
1955 // reduction instead.
1956 bool doTeamsReduction = !teamsReductionContainedInDistribute(op);
1957 if (doTeamsReduction) {
1958 isByRef = getIsByRef(op.getReductionByref());
1959
1960 assert(isByRef.size() == op.getNumReductionVars());
1961
1962 MutableArrayRef<BlockArgument> reductionArgs =
1963 llvm::cast<omp::BlockArgOpenMPOpInterface>(*op).getReductionBlockArgs();
1964
1965 collectReductionDecls(op, reductionDecls);
1966
1967 if (failed(allocAndInitializeReductionVars(
1968 op, reductionArgs, builder, moduleTranslation, allocaIP,
1969 reductionDecls, privateReductionVariables, reductionVariableMap,
1970 isByRef)))
1971 return failure();
1972 }
1973
1974 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
1975 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
1976 moduleTranslation, allocaIP);
1977 builder.restoreIP(IP: codegenIP);
1978 return convertOmpOpRegions(op.getRegion(), "omp.teams.region", builder,
1979 moduleTranslation)
1980 .takeError();
1981 };
1982
1983 llvm::Value *numTeamsLower = nullptr;
1984 if (Value numTeamsLowerVar = op.getNumTeamsLower())
1985 numTeamsLower = moduleTranslation.lookupValue(value: numTeamsLowerVar);
1986
1987 llvm::Value *numTeamsUpper = nullptr;
1988 if (Value numTeamsUpperVar = op.getNumTeamsUpper())
1989 numTeamsUpper = moduleTranslation.lookupValue(value: numTeamsUpperVar);
1990
1991 llvm::Value *threadLimit = nullptr;
1992 if (Value threadLimitVar = op.getThreadLimit())
1993 threadLimit = moduleTranslation.lookupValue(value: threadLimitVar);
1994
1995 llvm::Value *ifExpr = nullptr;
1996 if (Value ifVar = op.getIfExpr())
1997 ifExpr = moduleTranslation.lookupValue(value: ifVar);
1998
1999 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2000 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2001 moduleTranslation.getOpenMPBuilder()->createTeams(
2002 ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit, ifExpr);
2003
2004 if (failed(handleError(afterIP, *op)))
2005 return failure();
2006
2007 builder.restoreIP(IP: *afterIP);
2008 if (doTeamsReduction) {
2009 // Process the reductions if required.
2010 return createReductionsAndCleanup(
2011 op, builder, moduleTranslation, allocaIP, reductionDecls,
2012 privateReductionVariables, isByRef,
2013 /*isNoWait*/ false, /*isTeamsReduction*/ true);
2014 }
2015 return success();
2016}
2017
2018static void
2019buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
2020 LLVM::ModuleTranslation &moduleTranslation,
2021 SmallVectorImpl<llvm::OpenMPIRBuilder::DependData> &dds) {
2022 if (dependVars.empty())
2023 return;
2024 for (auto dep : llvm::zip(dependVars, dependKinds->getValue())) {
2025 llvm::omp::RTLDependenceKindTy type;
2026 switch (
2027 cast<mlir::omp::ClauseTaskDependAttr>(std::get<1>(dep)).getValue()) {
2028 case mlir::omp::ClauseTaskDepend::taskdependin:
2029 type = llvm::omp::RTLDependenceKindTy::DepIn;
2030 break;
2031 // The OpenMP runtime requires that the codegen for 'depend' clause for
2032 // 'out' dependency kind must be the same as codegen for 'depend' clause
2033 // with 'inout' dependency.
2034 case mlir::omp::ClauseTaskDepend::taskdependout:
2035 case mlir::omp::ClauseTaskDepend::taskdependinout:
2036 type = llvm::omp::RTLDependenceKindTy::DepInOut;
2037 break;
2038 case mlir::omp::ClauseTaskDepend::taskdependmutexinoutset:
2039 type = llvm::omp::RTLDependenceKindTy::DepMutexInOutSet;
2040 break;
2041 case mlir::omp::ClauseTaskDepend::taskdependinoutset:
2042 type = llvm::omp::RTLDependenceKindTy::DepInOutSet;
2043 break;
2044 };
2045 llvm::Value *depVal = moduleTranslation.lookupValue(std::get<0>(dep));
2046 llvm::OpenMPIRBuilder::DependData dd(type, depVal->getType(), depVal);
2047 dds.emplace_back(dd);
2048 }
2049}
2050
2051/// Shared implementation of a callback which adds a termiator for the new block
2052/// created for the branch taken when an openmp construct is cancelled. The
2053/// terminator is saved in \p cancelTerminators. This callback is invoked only
2054/// if there is cancellation inside of the taskgroup body.
2055/// The terminator will need to be fixed to branch to the correct block to
2056/// cleanup the construct.
2057static void
2058pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
2059 llvm::IRBuilderBase &llvmBuilder,
2060 llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
2061 llvm::omp::Directive cancelDirective) {
2062 auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2063 llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2064
2065 // ip is currently in the block branched to if cancellation occured.
2066 // We need to create a branch to terminate that block.
2067 llvmBuilder.restoreIP(IP: ip);
2068
2069 // We must still clean up the construct after cancelling it, so we need to
2070 // branch to the block that finalizes the taskgroup.
2071 // That block has not been created yet so use this block as a dummy for now
2072 // and fix this after creating the operation.
2073 cancelTerminators.push_back(Elt: llvmBuilder.CreateBr(Dest: ip.getBlock()));
2074 return llvm::Error::success();
2075 };
2076 // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2077 // created in case the body contains omp.cancel (which will then expect to be
2078 // able to find this cleanup callback).
2079 ompBuilder.pushFinalizationCB(
2080 FI: {finiCB, cancelDirective, constructIsCancellable(op)});
2081}
2082
2083/// If we cancelled the construct, we should branch to the finalization block of
2084/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
2085/// is immediately before the continuation block. Now this finalization has
2086/// been created we can fix the branch.
2087static void
2088popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
2089 llvm::OpenMPIRBuilder &ompBuilder,
2090 const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
2091 ompBuilder.popFinalizationCB();
2092 llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
2093 for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2094 assert(cancelBranch->getNumSuccessors() == 1 &&
2095 "cancel branch should have one target");
2096 cancelBranch->setSuccessor(idx: 0, NewSucc: constructFini);
2097 }
2098}
2099
2100namespace {
2101/// TaskContextStructManager takes care of creating and freeing a structure
2102/// containing information needed by the task body to execute.
2103class TaskContextStructManager {
2104public:
2105 TaskContextStructManager(llvm::IRBuilderBase &builder,
2106 LLVM::ModuleTranslation &moduleTranslation,
2107 MutableArrayRef<omp::PrivateClauseOp> privateDecls)
2108 : builder{builder}, moduleTranslation{moduleTranslation},
2109 privateDecls{privateDecls} {}
2110
2111 /// Creates a heap allocated struct containing space for each private
2112 /// variable. Invariant: privateVarTypes, privateDecls, and the elements of
2113 /// the structure should all have the same order (although privateDecls which
2114 /// do not read from the mold argument are skipped).
2115 void generateTaskContextStruct();
2116
2117 /// Create GEPs to access each member of the structure representing a private
2118 /// variable, adding them to llvmPrivateVars. Null values are added where
2119 /// private decls were skipped so that the ordering continues to match the
2120 /// private decls.
2121 void createGEPsToPrivateVars();
2122
2123 /// De-allocate the task context structure.
2124 void freeStructPtr();
2125
2126 MutableArrayRef<llvm::Value *> getLLVMPrivateVarGEPs() {
2127 return llvmPrivateVarGEPs;
2128 }
2129
2130 llvm::Value *getStructPtr() { return structPtr; }
2131
2132private:
2133 llvm::IRBuilderBase &builder;
2134 LLVM::ModuleTranslation &moduleTranslation;
2135 MutableArrayRef<omp::PrivateClauseOp> privateDecls;
2136
2137 /// The type of each member of the structure, in order.
2138 SmallVector<llvm::Type *> privateVarTypes;
2139
2140 /// LLVM values for each private variable, or null if that private variable is
2141 /// not included in the task context structure
2142 SmallVector<llvm::Value *> llvmPrivateVarGEPs;
2143
2144 /// A pointer to the structure containing context for this task.
2145 llvm::Value *structPtr = nullptr;
2146 /// The type of the structure
2147 llvm::Type *structTy = nullptr;
2148};
2149} // namespace
2150
2151void TaskContextStructManager::generateTaskContextStruct() {
2152 if (privateDecls.empty())
2153 return;
2154 privateVarTypes.reserve(privateDecls.size());
2155
2156 for (omp::PrivateClauseOp &privOp : privateDecls) {
2157 // Skip private variables which can safely be allocated and initialised
2158 // inside of the task
2159 if (!privOp.readsFromMold())
2160 continue;
2161 Type mlirType = privOp.getType();
2162 privateVarTypes.push_back(moduleTranslation.convertType(mlirType));
2163 }
2164
2165 structTy = llvm::StructType::get(Context&: moduleTranslation.getLLVMContext(),
2166 Elements: privateVarTypes);
2167
2168 llvm::DataLayout dataLayout =
2169 builder.GetInsertBlock()->getModule()->getDataLayout();
2170 llvm::Type *intPtrTy = builder.getIntPtrTy(DL: dataLayout);
2171 llvm::Constant *allocSize = llvm::ConstantExpr::getSizeOf(Ty: structTy);
2172
2173 // Heap allocate the structure
2174 structPtr = builder.CreateMalloc(IntPtrTy: intPtrTy, AllocTy: structTy, AllocSize: allocSize,
2175 /*ArraySize=*/nullptr, /*MallocF=*/nullptr,
2176 Name: "omp.task.context_ptr");
2177}
2178
2179void TaskContextStructManager::createGEPsToPrivateVars() {
2180 if (!structPtr) {
2181 assert(privateVarTypes.empty());
2182 return;
2183 }
2184
2185 // Create GEPs for each struct member
2186 llvmPrivateVarGEPs.clear();
2187 llvmPrivateVarGEPs.reserve(privateDecls.size());
2188 llvm::Value *zero = builder.getInt32(C: 0);
2189 unsigned i = 0;
2190 for (auto privDecl : privateDecls) {
2191 if (!privDecl.readsFromMold()) {
2192 // Handle this inside of the task so we don't pass unnessecary vars in
2193 llvmPrivateVarGEPs.push_back(nullptr);
2194 continue;
2195 }
2196 llvm::Value *iVal = builder.getInt32(i);
2197 llvm::Value *gep = builder.CreateGEP(structTy, structPtr, {zero, iVal});
2198 llvmPrivateVarGEPs.push_back(gep);
2199 i += 1;
2200 }
2201}
2202
2203void TaskContextStructManager::freeStructPtr() {
2204 if (!structPtr)
2205 return;
2206
2207 llvm::IRBuilderBase::InsertPointGuard guard{builder};
2208 // Ensure we don't put the call to free() after the terminator
2209 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2210 builder.CreateFree(Source: structPtr);
2211}
2212
2213/// Converts an OpenMP task construct into LLVM IR using OpenMPIRBuilder.
2214static LogicalResult
2215convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2216 LLVM::ModuleTranslation &moduleTranslation) {
2217 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2218 if (failed(checkImplementationStatus(*taskOp)))
2219 return failure();
2220
2221 PrivateVarsInfo privateVarsInfo(taskOp);
2222 TaskContextStructManager taskStructMgr{builder, moduleTranslation,
2223 privateVarsInfo.privatizers};
2224
2225 // Allocate and copy private variables before creating the task. This avoids
2226 // accessing invalid memory if (after this scope ends) the private variables
2227 // are initialized from host variables or if the variables are copied into
2228 // from host variables (firstprivate). The insertion point is just before
2229 // where the code for creating and scheduling the task will go. That puts this
2230 // code outside of the outlined task region, which is what we want because
2231 // this way the initialization and copy regions are executed immediately while
2232 // the host variable data are still live.
2233
2234 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2235 findAllocaInsertPoint(builder, moduleTranslation);
2236
2237 // Not using splitBB() because that requires the current block to have a
2238 // terminator.
2239 assert(builder.GetInsertPoint() == builder.GetInsertBlock()->end());
2240 llvm::BasicBlock *taskStartBlock = llvm::BasicBlock::Create(
2241 Context&: builder.getContext(), Name: "omp.task.start",
2242 /*Parent=*/builder.GetInsertBlock()->getParent());
2243 llvm::Instruction *branchToTaskStartBlock = builder.CreateBr(Dest: taskStartBlock);
2244 builder.SetInsertPoint(branchToTaskStartBlock);
2245
2246 // Now do this again to make the initialization and copy blocks
2247 llvm::BasicBlock *copyBlock =
2248 splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.copy");
2249 llvm::BasicBlock *initBlock =
2250 splitBB(Builder&: builder, /*CreateBranch=*/true, Name: "omp.private.init");
2251
2252 // Now the control flow graph should look like
2253 // starter_block:
2254 // <---- where we started when convertOmpTaskOp was called
2255 // br %omp.private.init
2256 // omp.private.init:
2257 // br %omp.private.copy
2258 // omp.private.copy:
2259 // br %omp.task.start
2260 // omp.task.start:
2261 // <---- where we want the insertion point to be when we call createTask()
2262
2263 // Save the alloca insertion point on ModuleTranslation stack for use in
2264 // nested regions.
2265 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
2266 moduleTranslation, allocaIP);
2267
2268 // Allocate and initialize private variables
2269 builder.SetInsertPoint(initBlock->getTerminator());
2270
2271 // Create task variable structure
2272 taskStructMgr.generateTaskContextStruct();
2273 // GEPs so that we can initialize the variables. Don't use these GEPs inside
2274 // of the body otherwise it will be the GEP not the struct which is fowarded
2275 // to the outlined function. GEPs forwarded in this way are passed in a
2276 // stack-allocated (by OpenMPIRBuilder) structure which is not safe for tasks
2277 // which may not be executed until after the current stack frame goes out of
2278 // scope.
2279 taskStructMgr.createGEPsToPrivateVars();
2280
2281 for (auto [privDecl, mlirPrivVar, blockArg, llvmPrivateVarAlloc] :
2282 llvm::zip_equal(privateVarsInfo.privatizers, privateVarsInfo.mlirVars,
2283 privateVarsInfo.blockArgs,
2284 taskStructMgr.getLLVMPrivateVarGEPs())) {
2285 // To be handled inside the task.
2286 if (!privDecl.readsFromMold())
2287 continue;
2288 assert(llvmPrivateVarAlloc &&
2289 "reads from mold so shouldn't have been skipped");
2290
2291 llvm::Expected<llvm::Value *> privateVarOrErr =
2292 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2293 blockArg, llvmPrivateVarAlloc, initBlock);
2294 if (!privateVarOrErr)
2295 return handleError(privateVarOrErr, *taskOp.getOperation());
2296
2297 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2298 builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
2299
2300 // TODO: this is a bit of a hack for Fortran character boxes.
2301 // Character boxes are passed by value into the init region and then the
2302 // initialized character box is yielded by value. Here we need to store the
2303 // yielded value into the private allocation, and load the private
2304 // allocation to match the type expected by region block arguments.
2305 if ((privateVarOrErr.get() != llvmPrivateVarAlloc) &&
2306 !mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2307 builder.CreateStore(privateVarOrErr.get(), llvmPrivateVarAlloc);
2308 // Load it so we have the value pointed to by the GEP
2309 llvmPrivateVarAlloc = builder.CreateLoad(privateVarOrErr.get()->getType(),
2310 llvmPrivateVarAlloc);
2311 }
2312 assert(llvmPrivateVarAlloc->getType() ==
2313 moduleTranslation.convertType(blockArg.getType()));
2314
2315 // Mapping blockArg -> llvmPrivateVarAlloc is done inside the body callback
2316 // so that OpenMPIRBuilder doesn't try to pass each GEP address through a
2317 // stack allocated structure.
2318 }
2319
2320 // firstprivate copy region
2321 setInsertPointForPossiblyEmptyBlock(builder, block: copyBlock);
2322 if (failed(copyFirstPrivateVars(
2323 taskOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2324 taskStructMgr.getLLVMPrivateVarGEPs(), privateVarsInfo.privatizers,
2325 taskOp.getPrivateNeedsBarrier())))
2326 return llvm::failure();
2327
2328 // Set up for call to createTask()
2329 builder.SetInsertPoint(taskStartBlock);
2330
2331 auto bodyCB = [&](InsertPointTy allocaIP,
2332 InsertPointTy codegenIP) -> llvm::Error {
2333 // Save the alloca insertion point on ModuleTranslation stack for use in
2334 // nested regions.
2335 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
2336 moduleTranslation, allocaIP);
2337
2338 // translate the body of the task:
2339 builder.restoreIP(IP: codegenIP);
2340
2341 llvm::BasicBlock *privInitBlock = nullptr;
2342 privateVarsInfo.llvmVars.resize(N: privateVarsInfo.blockArgs.size());
2343 for (auto [i, zip] : llvm::enumerate(llvm::zip_equal(
2344 privateVarsInfo.blockArgs, privateVarsInfo.privatizers,
2345 privateVarsInfo.mlirVars))) {
2346 auto [blockArg, privDecl, mlirPrivVar] = zip;
2347 // This is handled before the task executes
2348 if (privDecl.readsFromMold())
2349 continue;
2350
2351 llvm::IRBuilderBase::InsertPointGuard guard(builder);
2352 llvm::Type *llvmAllocType =
2353 moduleTranslation.convertType(privDecl.getType());
2354 builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
2355 llvm::Value *llvmPrivateVar = builder.CreateAlloca(
2356 llvmAllocType, /*ArraySize=*/nullptr, "omp.private.alloc");
2357
2358 llvm::Expected<llvm::Value *> privateVarOrError =
2359 initPrivateVar(builder, moduleTranslation, privDecl, mlirPrivVar,
2360 blockArg, llvmPrivateVar, privInitBlock);
2361 if (!privateVarOrError)
2362 return privateVarOrError.takeError();
2363 moduleTranslation.mapValue(blockArg, privateVarOrError.get());
2364 privateVarsInfo.llvmVars[i] = privateVarOrError.get();
2365 }
2366
2367 taskStructMgr.createGEPsToPrivateVars();
2368 for (auto [i, llvmPrivVar] :
2369 llvm::enumerate(taskStructMgr.getLLVMPrivateVarGEPs())) {
2370 if (!llvmPrivVar) {
2371 assert(privateVarsInfo.llvmVars[i] &&
2372 "This is added in the loop above");
2373 continue;
2374 }
2375 privateVarsInfo.llvmVars[i] = llvmPrivVar;
2376 }
2377
2378 // Find and map the addresses of each variable within the task context
2379 // structure
2380 for (auto [blockArg, llvmPrivateVar, privateDecl] :
2381 llvm::zip_equal(privateVarsInfo.blockArgs, privateVarsInfo.llvmVars,
2382 privateVarsInfo.privatizers)) {
2383 // This was handled above.
2384 if (!privateDecl.readsFromMold())
2385 continue;
2386 // Fix broken pass-by-value case for Fortran character boxes
2387 if (!mlir::isa<LLVM::LLVMPointerType>(blockArg.getType())) {
2388 llvmPrivateVar = builder.CreateLoad(
2389 moduleTranslation.convertType(blockArg.getType()), llvmPrivateVar);
2390 }
2391 assert(llvmPrivateVar->getType() ==
2392 moduleTranslation.convertType(blockArg.getType()));
2393 moduleTranslation.mapValue(blockArg, llvmPrivateVar);
2394 }
2395
2396 auto continuationBlockOrError = convertOmpOpRegions(
2397 taskOp.getRegion(), "omp.task.region", builder, moduleTranslation);
2398 if (failed(handleError(continuationBlockOrError, *taskOp)))
2399 return llvm::make_error<PreviouslyReportedError>();
2400
2401 builder.SetInsertPoint(continuationBlockOrError.get()->getTerminator());
2402
2403 if (failed(cleanupPrivateVars(builder, moduleTranslation, taskOp.getLoc(),
2404 privateVarsInfo.llvmVars,
2405 privateVarsInfo.privatizers)))
2406 return llvm::make_error<PreviouslyReportedError>();
2407
2408 // Free heap allocated task context structure at the end of the task.
2409 taskStructMgr.freeStructPtr();
2410
2411 return llvm::Error::success();
2412 };
2413
2414 llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2415 SmallVector<llvm::BranchInst *> cancelTerminators;
2416 // The directive to match here is OMPD_taskgroup because it is the taskgroup
2417 // which is canceled. This is handled here because it is the task's cleanup
2418 // block which should be branched to.
2419 pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2420 llvm::omp::Directive::OMPD_taskgroup);
2421
2422 SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
2423 buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
2424 moduleTranslation, dds);
2425
2426 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2427 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2428 moduleTranslation.getOpenMPBuilder()->createTask(
2429 Loc: ompLoc, AllocaIP: allocaIP, BodyGenCB: bodyCB, Tied: !taskOp.getUntied(),
2430 Final: moduleTranslation.lookupValue(value: taskOp.getFinal()),
2431 IfCondition: moduleTranslation.lookupValue(value: taskOp.getIfExpr()), Dependencies: dds,
2432 Mergeable: taskOp.getMergeable(),
2433 EventHandle: moduleTranslation.lookupValue(value: taskOp.getEventHandle()),
2434 Priority: moduleTranslation.lookupValue(value: taskOp.getPriority()));
2435
2436 if (failed(handleError(afterIP, *taskOp)))
2437 return failure();
2438
2439 // Set the correct branch target for task cancellation
2440 popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP: afterIP.get());
2441
2442 builder.restoreIP(IP: *afterIP);
2443 return success();
2444}
2445
2446/// Converts an OpenMP taskgroup construct into LLVM IR using OpenMPIRBuilder.
2447static LogicalResult
2448convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
2449 LLVM::ModuleTranslation &moduleTranslation) {
2450 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
2451 if (failed(checkImplementationStatus(*tgOp)))
2452 return failure();
2453
2454 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
2455 builder.restoreIP(IP: codegenIP);
2456 return convertOmpOpRegions(tgOp.getRegion(), "omp.taskgroup.region",
2457 builder, moduleTranslation)
2458 .takeError();
2459 };
2460
2461 InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
2462 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2463 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2464 moduleTranslation.getOpenMPBuilder()->createTaskgroup(ompLoc, allocaIP,
2465 bodyCB);
2466
2467 if (failed(handleError(afterIP, *tgOp)))
2468 return failure();
2469
2470 builder.restoreIP(IP: *afterIP);
2471 return success();
2472}
2473
2474static LogicalResult
2475convertOmpTaskwaitOp(omp::TaskwaitOp twOp, llvm::IRBuilderBase &builder,
2476 LLVM::ModuleTranslation &moduleTranslation) {
2477 if (failed(checkImplementationStatus(*twOp)))
2478 return failure();
2479
2480 moduleTranslation.getOpenMPBuilder()->createTaskwait(Loc: builder.saveIP());
2481 return success();
2482}
2483
2484/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
2485static LogicalResult
2486convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2487 LLVM::ModuleTranslation &moduleTranslation) {
2488 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2489 auto wsloopOp = cast<omp::WsloopOp>(opInst);
2490 if (failed(Result: checkImplementationStatus(op&: opInst)))
2491 return failure();
2492
2493 auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
2494 llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
2495 assert(isByRef.size() == wsloopOp.getNumReductionVars());
2496
2497 // Static is the default.
2498 auto schedule =
2499 wsloopOp.getScheduleKind().value_or(omp::ClauseScheduleKind::Static);
2500
2501 // Find the loop configuration.
2502 llvm::Value *step = moduleTranslation.lookupValue(value: loopOp.getLoopSteps()[0]);
2503 llvm::Type *ivType = step->getType();
2504 llvm::Value *chunk = nullptr;
2505 if (wsloopOp.getScheduleChunk()) {
2506 llvm::Value *chunkVar =
2507 moduleTranslation.lookupValue(value: wsloopOp.getScheduleChunk());
2508 chunk = builder.CreateSExtOrTrunc(V: chunkVar, DestTy: ivType);
2509 }
2510
2511 PrivateVarsInfo privateVarsInfo(wsloopOp);
2512
2513 SmallVector<omp::DeclareReductionOp> reductionDecls;
2514 collectReductionDecls(wsloopOp, reductionDecls);
2515 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2516 findAllocaInsertPoint(builder, moduleTranslation);
2517
2518 SmallVector<llvm::Value *> privateReductionVariables(
2519 wsloopOp.getNumReductionVars());
2520
2521 llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
2522 builder, moduleTranslation, privateVarsInfo, allocaIP);
2523 if (handleError(result&: afterAllocas, op&: opInst).failed())
2524 return failure();
2525
2526 DenseMap<Value, llvm::Value *> reductionVariableMap;
2527
2528 MutableArrayRef<BlockArgument> reductionArgs =
2529 cast<omp::BlockArgOpenMPOpInterface>(opInst).getReductionBlockArgs();
2530
2531 SmallVector<DeferredStore> deferredStores;
2532
2533 if (failed(allocReductionVars(wsloopOp, reductionArgs, builder,
2534 moduleTranslation, allocaIP, reductionDecls,
2535 privateReductionVariables, reductionVariableMap,
2536 deferredStores, isByRef)))
2537 return failure();
2538
2539 if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo),
2540 op&: opInst)
2541 .failed())
2542 return failure();
2543
2544 if (failed(copyFirstPrivateVars(
2545 wsloopOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
2546 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
2547 wsloopOp.getPrivateNeedsBarrier())))
2548 return failure();
2549
2550 assert(afterAllocas.get()->getSinglePredecessor());
2551 if (failed(initReductionVars(wsloopOp, reductionArgs, builder,
2552 moduleTranslation,
2553 afterAllocas.get()->getSinglePredecessor(),
2554 reductionDecls, privateReductionVariables,
2555 reductionVariableMap, isByRef, deferredStores)))
2556 return failure();
2557
2558 // TODO: Handle doacross loops when the ordered clause has a parameter.
2559 bool isOrdered = wsloopOp.getOrdered().has_value();
2560 std::optional<omp::ScheduleModifier> scheduleMod = wsloopOp.getScheduleMod();
2561 bool isSimd = wsloopOp.getScheduleSimd();
2562 bool loopNeedsBarrier = !wsloopOp.getNowait();
2563
2564 // The only legal way for the direct parent to be omp.distribute is that this
2565 // represents 'distribute parallel do'. Otherwise, this is a regular
2566 // worksharing loop.
2567 llvm::omp::WorksharingLoopType workshareLoopType =
2568 llvm::isa_and_present<omp::DistributeOp>(opInst.getParentOp())
2569 ? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
2570 : llvm::omp::WorksharingLoopType::ForStaticLoop;
2571
2572 SmallVector<llvm::BranchInst *> cancelTerminators;
2573 pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
2574 llvm::omp::Directive::OMPD_for);
2575
2576 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
2577
2578 // Initialize linear variables and linear step
2579 LinearClauseProcessor linearClauseProcessor;
2580 if (wsloopOp.getLinearVars().size()) {
2581 for (mlir::Value linearVar : wsloopOp.getLinearVars())
2582 linearClauseProcessor.createLinearVar(builder, moduleTranslation,
2583 linearVar);
2584 for (mlir::Value linearStep : wsloopOp.getLinearStepVars())
2585 linearClauseProcessor.initLinearStep(moduleTranslation, linearStep);
2586 }
2587
2588 llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
2589 wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
2590
2591 if (failed(Result: handleError(result&: regionBlock, op&: opInst)))
2592 return failure();
2593
2594 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2595
2596 // Emit Initialization and Update IR for linear variables
2597 if (wsloopOp.getLinearVars().size()) {
2598 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2599 linearClauseProcessor.initLinearVar(builder, moduleTranslation,
2600 loopPreHeader: loopInfo->getPreheader());
2601 if (failed(handleError(afterBarrierIP, *loopOp)))
2602 return failure();
2603 builder.restoreIP(IP: *afterBarrierIP);
2604 linearClauseProcessor.updateLinearVar(builder, loopBody: loopInfo->getBody(),
2605 loopInductionVar: loopInfo->getIndVar());
2606 linearClauseProcessor.outlineLinearFinalizationBB(builder,
2607 loopExit: loopInfo->getExit());
2608 }
2609
2610 builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
2611 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2612 ompBuilder->applyWorkshareLoop(
2613 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2614 convertToScheduleKind(schedule), chunk, isSimd,
2615 scheduleMod == omp::ScheduleModifier::monotonic,
2616 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2617 workshareLoopType);
2618
2619 if (failed(Result: handleError(result&: wsloopIP, op&: opInst)))
2620 return failure();
2621
2622 // Emit finalization and in-place rewrites for linear vars.
2623 if (wsloopOp.getLinearVars().size()) {
2624 llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP();
2625 assert(loopInfo->getLastIter() &&
2626 "`lastiter` in CanonicalLoopInfo is nullptr");
2627 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP =
2628 linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation,
2629 lastIter: loopInfo->getLastIter());
2630 if (failed(handleError(afterBarrierIP, *loopOp)))
2631 return failure();
2632 for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++)
2633 linearClauseProcessor.rewriteInPlace(builder, BBName: "omp.loop_nest.region",
2634 varIndex: index);
2635 builder.restoreIP(IP: oldIP);
2636 }
2637
2638 // Set the correct branch target for task cancellation
2639 popCancelFinalizationCB(cancelTerminators, ompBuilder&: *ompBuilder, afterIP: wsloopIP.get());
2640
2641 // Process the reductions if required.
2642 if (failed(createReductionsAndCleanup(
2643 wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2644 privateReductionVariables, isByRef, wsloopOp.getNowait(),
2645 /*isTeamsReduction=*/false)))
2646 return failure();
2647
2648 return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
2649 privateVarsInfo.llvmVars,
2650 privateVarsInfo.privatizers);
2651}
2652
/// Converts the OpenMP parallel operation to LLVM IR.
///
/// Translation is delegated to OpenMPIRBuilder::createParallel through three
/// callbacks defined below:
///  - `bodyGenCB` does three things:
///      - allocates and initializes the private and reduction variables;
///      - translates the `omp.parallel` region;
///      - emits the reductions at the end of the region.
///  - `privCB` is a no-op, since privatization is already handled in
///    `bodyGenCB`.
///  - `finiCB` inlines the reductions' `cleanup` regions and cleans up the
///    private variables.
/// The `if`, `num_threads` and `proc_bind` clauses, and whether the construct
/// is cancellable, are forwarded to createParallel.
///
/// \returns failure if any part of the translation reports an error.
static LogicalResult
convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
  ArrayRef<bool> isByRef = getIsByRef(opInst.getReductionByref());
  assert(isByRef.size() == opInst.getNumReductionVars());
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  if (failed(checkImplementationStatus(*opInst)))
    return failure();

  PrivateVarsInfo privateVarsInfo(opInst);

  // Collect reduction declarations
  SmallVector<omp::DeclareReductionOp> reductionDecls;
  collectReductionDecls(opInst, reductionDecls);
  SmallVector<llvm::Value *> privateReductionVariables(
      opInst.getNumReductionVars());
  SmallVector<DeferredStore> deferredStores;

  // Body callback: allocations go to `allocaIP`; initialization, the region
  // body and the reductions are emitted starting at `codeGenIP`.
  auto bodyGenCB = [&](InsertPointTy allocaIP,
                       InsertPointTy codeGenIP) -> llvm::Error {
    llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
        builder, moduleTranslation, privateVarsInfo, allocaIP);
    if (handleError(afterAllocas, *opInst).failed())
      return llvm::make_error<PreviouslyReportedError>();

    // Allocate reduction vars
    DenseMap<Value, llvm::Value *> reductionVariableMap;

    MutableArrayRef<BlockArgument> reductionArgs =
        cast<omp::BlockArgOpenMPOpInterface>(*opInst).getReductionBlockArgs();

    // Re-anchor the alloca insertion point just before the terminator of the
    // alloca block, so the reduction allocations are appended to that block.
    allocaIP =
        InsertPointTy(allocaIP.getBlock(),
                      allocaIP.getBlock()->getTerminator()->getIterator());

    if (failed(allocReductionVars(
            opInst, reductionArgs, builder, moduleTranslation, allocaIP,
            reductionDecls, privateReductionVariables, reductionVariableMap,
            deferredStores, isByRef)))
      return llvm::make_error<PreviouslyReportedError>();

    // Reduction initialization below is emitted relative to the single
    // predecessor of the block returned by allocatePrivateVars.
    assert(afterAllocas.get()->getSinglePredecessor());
    builder.restoreIP(IP: codeGenIP);

    if (handleError(
            initPrivateVars(builder, moduleTranslation, privateVarsInfo),
            *opInst)
            .failed())
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(copyFirstPrivateVars(
            opInst, builder, moduleTranslation, privateVarsInfo.mlirVars,
            privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
            opInst.getPrivateNeedsBarrier())))
      return llvm::make_error<PreviouslyReportedError>();

    if (failed(
            initReductionVars(opInst, reductionArgs, builder, moduleTranslation,
                              afterAllocas.get()->getSinglePredecessor(),
                              reductionDecls, privateReductionVariables,
                              reductionVariableMap, isByRef, deferredStores)))
      return llvm::make_error<PreviouslyReportedError>();

    // Save the alloca insertion point on ModuleTranslation stack for use in
    // nested regions.
    LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
        moduleTranslation, allocaIP);

    // ParallelOp has only one region associated with it.
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        opInst.getRegion(), "omp.par.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    // Process the reductions if required.
    if (opInst.getNumReductionVars() > 0) {
      // Collect reduction info
      SmallVector<OwningReductionGen> owningReductionGens;
      SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
      SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
      collectReductionInfo(opInst, builder, moduleTranslation, reductionDecls,
                           owningReductionGens, owningAtomicReductionGens,
                           privateReductionVariables, reductionInfos);

      // Move to region cont block
      builder.SetInsertPoint((*regionBlock)->getTerminator());

      // Generate reductions from info. A temporary `unreachable` is placed
      // before the region's terminator to anchor the insertion point; it is
      // erased once createReductions has succeeded.
      llvm::UnreachableInst *tempTerminator = builder.CreateUnreachable();
      builder.SetInsertPoint(tempTerminator);

      llvm::OpenMPIRBuilder::InsertPointOrErrorTy contInsertPoint =
          ompBuilder->createReductions(
              Loc: builder.saveIP(), AllocaIP: allocaIP, ReductionInfos: reductionInfos, IsByRef: isByRef,
              /*IsNoWait=*/false, /*IsTeamsReduction=*/false);
      if (!contInsertPoint)
        return contInsertPoint.takeError();

      // An insertion point without a block is treated as an error here.
      if (!contInsertPoint->getBlock())
        return llvm::make_error<PreviouslyReportedError>();

      tempTerminator->eraseFromParent();
      builder.restoreIP(IP: *contInsertPoint);
    }

    return llvm::Error::success();
  };

  auto privCB = [](InsertPointTy allocaIP, InsertPointTy codeGenIP,
                   llvm::Value &, llvm::Value &val, llvm::Value *&replVal) {
    // tell OpenMPIRBuilder not to do anything. We handled Privatisation in
    // bodyGenCB.
    replVal = &val;
    return codeGenIP;
  };

  // Finalization callback. It inlines the `cleanup` regions of the reduction
  // declarations and cleans up the private variables, then restores the
  // builder's previous insertion point.
  // TODO: Perform finalization actions for variables. This has to be
  // called for variables which have destructors/finalizers.
  auto finiCB = [&](InsertPointTy codeGenIP) -> llvm::Error {
    InsertPointTy oldIP = builder.saveIP();
    builder.restoreIP(IP: codeGenIP);

    // if the reduction has a cleanup region, inline it here to finalize the
    // reduction variables
    SmallVector<Region *> reductionCleanupRegions;
    llvm::transform(reductionDecls, std::back_inserter(x&: reductionCleanupRegions),
                    [](omp::DeclareReductionOp reductionDecl) {
                      return &reductionDecl.getCleanupRegion();
                    });
    if (failed(Result: inlineOmpRegionCleanup(
            cleanupRegions&: reductionCleanupRegions, privateVariables: privateReductionVariables,
            moduleTranslation, builder, regionName: "omp.reduction.cleanup")))
      return llvm::createStringError(
          Fmt: "failed to inline `cleanup` region of `omp.declare_reduction`");

    if (failed(cleanupPrivateVars(builder, moduleTranslation, opInst.getLoc(),
                                  privateVarsInfo.llvmVars,
                                  privateVarsInfo.privatizers)))
      return llvm::make_error<PreviouslyReportedError>();

    builder.restoreIP(IP: oldIP);
    return llvm::Error::success();
  };

  // Optional clauses: a null value means the clause is absent.
  llvm::Value *ifCond = nullptr;
  if (auto ifVar = opInst.getIfExpr())
    ifCond = moduleTranslation.lookupValue(value: ifVar);
  llvm::Value *numThreads = nullptr;
  if (auto numThreadsVar = opInst.getNumThreads())
    numThreads = moduleTranslation.lookupValue(value: numThreadsVar);
  auto pbKind = llvm::omp::OMP_PROC_BIND_default;
  if (auto bind = opInst.getProcBindKind())
    pbKind = getProcBindKind(*bind);
  bool isCancellable = constructIsCancellable(opInst);

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
      findAllocaInsertPoint(builder, moduleTranslation);
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
                                 ifCond, numThreads, pbKind, isCancellable);

  if (failed(handleError(afterIP, *opInst)))
    return failure();

  builder.restoreIP(IP: *afterIP);
  return success();
}
2825
2826/// Convert Order attribute to llvm::omp::OrderKind.
2827static llvm::omp::OrderKind
2828convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2829 if (!o)
2830 return llvm::omp::OrderKind::OMP_ORDER_unknown;
2831 switch (*o) {
2832 case omp::ClauseOrderKind::Concurrent:
2833 return llvm::omp::OrderKind::OMP_ORDER_concurrent;
2834 }
2835 llvm_unreachable("Unknown ClauseOrderKind kind");
2836}
2837
2838/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
2839static LogicalResult
2840convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2841 LLVM::ModuleTranslation &moduleTranslation) {
2842 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
2843 auto simdOp = cast<omp::SimdOp>(opInst);
2844
2845 // TODO: Replace this with proper composite translation support.
2846 // Currently, simd information on composite constructs is ignored, so e.g.
2847 // 'do/for simd' will be treated the same as a standalone 'do/for'. This is
2848 // allowed by the spec, since it's equivalent to using a SIMD length of 1.
2849 if (simdOp.isComposite()) {
2850 if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
2851 return failure();
2852
2853 return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
2854 builder, moduleTranslation);
2855 }
2856
2857 if (failed(Result: checkImplementationStatus(op&: opInst)))
2858 return failure();
2859
2860 PrivateVarsInfo privateVarsInfo(simdOp);
2861
2862 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2863 findAllocaInsertPoint(builder, moduleTranslation);
2864
2865 llvm::Expected<llvm::BasicBlock *> afterAllocas = allocatePrivateVars(
2866 builder, moduleTranslation, privateVarsInfo, allocaIP);
2867 if (handleError(result&: afterAllocas, op&: opInst).failed())
2868 return failure();
2869
2870 if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo),
2871 op&: opInst)
2872 .failed())
2873 return failure();
2874
2875 llvm::ConstantInt *simdlen = nullptr;
2876 if (std::optional<uint64_t> simdlenVar = simdOp.getSimdlen())
2877 simdlen = builder.getInt64(C: simdlenVar.value());
2878
2879 llvm::ConstantInt *safelen = nullptr;
2880 if (std::optional<uint64_t> safelenVar = simdOp.getSafelen())
2881 safelen = builder.getInt64(C: safelenVar.value());
2882
2883 llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
2884 llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
2885
2886 llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
2887 std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
2888 mlir::OperandRange operands = simdOp.getAlignedVars();
2889 for (size_t i = 0; i < operands.size(); ++i) {
2890 llvm::Value *alignment = nullptr;
2891 llvm::Value *llvmVal = moduleTranslation.lookupValue(value: operands[i]);
2892 llvm::Type *ty = llvmVal->getType();
2893
2894 auto intAttr = cast<IntegerAttr>((*alignmentValues)[i]);
2895 alignment = builder.getInt64(C: intAttr.getInt());
2896 assert(ty->isPointerTy() && "Invalid type for aligned variable");
2897 assert(alignment && "Invalid alignment value");
2898 auto curInsert = builder.saveIP();
2899 builder.SetInsertPoint(sourceBlock);
2900 llvmVal = builder.CreateLoad(Ty: ty, Ptr: llvmVal);
2901 builder.restoreIP(IP: curInsert);
2902 alignedVars[llvmVal] = alignment;
2903 }
2904
2905 llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
2906 simdOp.getRegion(), "omp.simd.region", builder, moduleTranslation);
2907
2908 if (failed(Result: handleError(result&: regionBlock, op&: opInst)))
2909 return failure();
2910
2911 builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
2912 llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2913 ompBuilder->applySimd(loopInfo, alignedVars,
2914 simdOp.getIfExpr()
2915 ? moduleTranslation.lookupValue(value: simdOp.getIfExpr())
2916 : nullptr,
2917 order, simdlen, safelen);
2918
2919 return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
2920 privateVarsInfo.llvmVars,
2921 privateVarsInfo.privatizers);
2922}
2923
/// Converts an OpenMP loop nest into LLVM IR using OpenMPIRBuilder.
///
/// One canonical loop is created per loop of the nest. Each inner loop is
/// created at the body insertion point of the loop that encloses it. The
/// `omp.loop_nest` region itself is translated only inside the innermost
/// loop.
///
/// The loops are then collapsed. The first OpenMPLoopInfoStackFrame reached by
/// the stack walk is updated to refer to the collapsed loop, and IR emission
/// continues after the outermost loop.
static LogicalResult
convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
                   LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
  auto loopOp = cast<omp::LoopNestOp>(opInst);

  // Set up the source location value for OpenMP runtime.
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);

  // Generator of the canonical loop body.
  //
  // It is called back from createCanonicalLoop for each loop. The
  // CanonicalLoopInfo of the loop being built is only appended to `loopInfos`
  // after createCanonicalLoop returns. `loopInfos.size()` is therefore the
  // index of the current loop within the nest.
  SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
  SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
  auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip,
                     llvm::Value *iv) -> llvm::Error {
    // Make sure further conversions know about the induction variable.
    moduleTranslation.mapValue(
        loopOp.getRegion().front().getArgument(loopInfos.size()), iv);

    // Capture the body insertion point for use in nested loops. BodyIP of the
    // CanonicalLoopInfo always points to the beginning of the entry block of
    // the body.
    bodyInsertPoints.push_back(Elt: ip);

    // Only the innermost loop gets the translated region as its body.
    if (loopInfos.size() != loopOp.getNumLoops() - 1)
      return llvm::Error::success();

    // Convert the body of the loop.
    builder.restoreIP(IP: ip);
    llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
        loopOp.getRegion(), "omp.loop_nest.region", builder, moduleTranslation);
    if (!regionBlock)
      return regionBlock.takeError();

    builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
    return llvm::Error::success();
  };

  // Delegate actual loop construction to the OpenMP IRBuilder.
  // TODO: this currently assumes omp.loop_nest is semantically similar to SCF
  // loop, i.e. it has a positive step, uses signed integer semantics.
  // Reconsider this code when the nested loop operation clearly supports more
  // cases.
  for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
    llvm::Value *lowerBound =
        moduleTranslation.lookupValue(value: loopOp.getLoopLowerBounds()[i]);
    llvm::Value *upperBound =
        moduleTranslation.lookupValue(value: loopOp.getLoopUpperBounds()[i]);
    llvm::Value *step = moduleTranslation.lookupValue(value: loopOp.getLoopSteps()[i]);

    // Make sure the loop trip counts are emitted in the preheader of the
    // outermost loop at the latest, so that they are all available for the
    // collapsed loop created below. Inner loops are created at the body
    // insertion point of the enclosing loop.
    llvm::OpenMPIRBuilder::LocationDescription loc = ompLoc;
    llvm::OpenMPIRBuilder::InsertPointTy computeIP = ompLoc.IP;
    if (i != 0) {
      loc = llvm::OpenMPIRBuilder::LocationDescription(bodyInsertPoints.back(),
                                                       ompLoc.DL);
      computeIP = loopInfos.front()->getPreheaderIP();
    }

    llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
        ompBuilder->createCanonicalLoop(
            loc, bodyGen, lowerBound, upperBound, step,
            /*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);

    if (failed(handleError(loopResult, *loopOp)))
      return failure();

    loopInfos.push_back(Elt: *loopResult);
  }

  // Collapse loops. Store the insertion point because LoopInfos may get
  // invalidated.
  llvm::OpenMPIRBuilder::InsertPointTy afterIP =
      loopInfos.front()->getAfterIP();

  // Update the stack frame created for this loop to point to the resulting loop
  // after applying transformations. Only the first frame visited is updated,
  // because the walk is interrupted immediately.
  moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
      callback: [&](OpenMPLoopInfoStackFrame &frame) {
        frame.loopInfo = ompBuilder->collapseLoops(DL: ompLoc.DL, Loops: loopInfos, ComputeIP: {});
        return WalkResult::interrupt();
      });

  // Continue building IR after the loop. Note that the LoopInfo returned by
  // `collapseLoops` points inside the outermost loop and is intended for
  // potential further loop transformations. Use the insertion point stored
  // before collapsing loops instead.
  builder.restoreIP(IP: afterIP);
  return success();
}
3016
3017/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
3018static llvm::AtomicOrdering
3019convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
3020 if (!ao)
3021 return llvm::AtomicOrdering::Monotonic; // Default Memory Ordering
3022
3023 switch (*ao) {
3024 case omp::ClauseMemoryOrderKind::Seq_cst:
3025 return llvm::AtomicOrdering::SequentiallyConsistent;
3026 case omp::ClauseMemoryOrderKind::Acq_rel:
3027 return llvm::AtomicOrdering::AcquireRelease;
3028 case omp::ClauseMemoryOrderKind::Acquire:
3029 return llvm::AtomicOrdering::Acquire;
3030 case omp::ClauseMemoryOrderKind::Release:
3031 return llvm::AtomicOrdering::Release;
3032 case omp::ClauseMemoryOrderKind::Relaxed:
3033 return llvm::AtomicOrdering::Monotonic;
3034 }
3035 llvm_unreachable("Unknown ClauseMemoryOrderKind kind");
3036}
3037
3038/// Convert omp.atomic.read operation to LLVM IR.
3039static LogicalResult
3040convertOmpAtomicRead(Operation &opInst, llvm::IRBuilderBase &builder,
3041 LLVM::ModuleTranslation &moduleTranslation) {
3042 auto readOp = cast<omp::AtomicReadOp>(opInst);
3043 if (failed(Result: checkImplementationStatus(op&: opInst)))
3044 return failure();
3045
3046 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3047 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3048 findAllocaInsertPoint(builder, moduleTranslation);
3049
3050 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3051
3052 llvm::AtomicOrdering AO = convertAtomicOrdering(readOp.getMemoryOrder());
3053 llvm::Value *x = moduleTranslation.lookupValue(value: readOp.getX());
3054 llvm::Value *v = moduleTranslation.lookupValue(value: readOp.getV());
3055
3056 llvm::Type *elementType =
3057 moduleTranslation.convertType(type: readOp.getElementType());
3058
3059 llvm::OpenMPIRBuilder::AtomicOpValue V = {.Var: v, .ElemTy: elementType, .IsSigned: false, .IsVolatile: false};
3060 llvm::OpenMPIRBuilder::AtomicOpValue X = {.Var: x, .ElemTy: elementType, .IsSigned: false, .IsVolatile: false};
3061 builder.restoreIP(IP: ompBuilder->createAtomicRead(Loc: ompLoc, X, V, AO, AllocaIP: allocaIP));
3062 return success();
3063}
3064
3065/// Converts an omp.atomic.write operation to LLVM IR.
3066static LogicalResult
3067convertOmpAtomicWrite(Operation &opInst, llvm::IRBuilderBase &builder,
3068 LLVM::ModuleTranslation &moduleTranslation) {
3069 auto writeOp = cast<omp::AtomicWriteOp>(opInst);
3070 if (failed(Result: checkImplementationStatus(op&: opInst)))
3071 return failure();
3072
3073 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3074 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
3075 findAllocaInsertPoint(builder, moduleTranslation);
3076
3077 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3078 llvm::AtomicOrdering ao = convertAtomicOrdering(writeOp.getMemoryOrder());
3079 llvm::Value *expr = moduleTranslation.lookupValue(value: writeOp.getExpr());
3080 llvm::Value *dest = moduleTranslation.lookupValue(value: writeOp.getX());
3081 llvm::Type *ty = moduleTranslation.convertType(type: writeOp.getExpr().getType());
3082 llvm::OpenMPIRBuilder::AtomicOpValue x = {.Var: dest, .ElemTy: ty, /*isSigned=*/.IsSigned: false,
3083 /*isVolatile=*/.IsVolatile: false};
3084 builder.restoreIP(
3085 IP: ompBuilder->createAtomicWrite(Loc: ompLoc, X&: x, Expr: expr, AO: ao, AllocaIP: allocaIP));
3086 return success();
3087}
3088
3089/// Converts an LLVM dialect binary operation to the corresponding enum value
3090/// for `atomicrmw` supported binary operation.
3091llvm::AtomicRMWInst::BinOp convertBinOpToAtomic(Operation &op) {
3092 return llvm::TypeSwitch<Operation *, llvm::AtomicRMWInst::BinOp>(&op)
3093 .Case(caseFn: [&](LLVM::AddOp) { return llvm::AtomicRMWInst::BinOp::Add; })
3094 .Case(caseFn: [&](LLVM::SubOp) { return llvm::AtomicRMWInst::BinOp::Sub; })
3095 .Case(caseFn: [&](LLVM::AndOp) { return llvm::AtomicRMWInst::BinOp::And; })
3096 .Case(caseFn: [&](LLVM::OrOp) { return llvm::AtomicRMWInst::BinOp::Or; })
3097 .Case(caseFn: [&](LLVM::XOrOp) { return llvm::AtomicRMWInst::BinOp::Xor; })
3098 .Case(caseFn: [&](LLVM::UMaxOp) { return llvm::AtomicRMWInst::BinOp::UMax; })
3099 .Case(caseFn: [&](LLVM::UMinOp) { return llvm::AtomicRMWInst::BinOp::UMin; })
3100 .Case(caseFn: [&](LLVM::FAddOp) { return llvm::AtomicRMWInst::BinOp::FAdd; })
3101 .Case(caseFn: [&](LLVM::FSubOp) { return llvm::AtomicRMWInst::BinOp::FSub; })
3102 .Default(defaultResult: llvm::AtomicRMWInst::BinOp::BAD_BINOP);
3103}
3104
3105/// Converts an OpenMP atomic update operation using OpenMPIRBuilder.
3106static LogicalResult
3107convertOmpAtomicUpdate(omp::AtomicUpdateOp &opInst,
3108 llvm::IRBuilderBase &builder,
3109 LLVM::ModuleTranslation &moduleTranslation) {
3110 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3111 if (failed(checkImplementationStatus(*opInst)))
3112 return failure();
3113
3114 // Convert values and types.
3115 auto &innerOpList = opInst.getRegion().front().getOperations();
3116 bool isXBinopExpr{false};
3117 llvm::AtomicRMWInst::BinOp binop;
3118 mlir::Value mlirExpr;
3119 llvm::Value *llvmExpr = nullptr;
3120 llvm::Value *llvmX = nullptr;
3121 llvm::Type *llvmXElementType = nullptr;
3122 if (innerOpList.size() == 2) {
3123 // The two operations here are the update and the terminator.
3124 // Since we can identify the update operation, there is a possibility
3125 // that we can generate the atomicrmw instruction.
3126 mlir::Operation &innerOp = *opInst.getRegion().front().begin();
3127 if (!llvm::is_contained(innerOp.getOperands(),
3128 opInst.getRegion().getArgument(0))) {
3129 return opInst.emitError("no atomic update operation with region argument"
3130 " as operand found inside atomic.update region");
3131 }
3132 binop = convertBinOpToAtomic(op&: innerOp);
3133 isXBinopExpr = innerOp.getOperand(idx: 0) == opInst.getRegion().getArgument(0);
3134 mlirExpr = (isXBinopExpr ? innerOp.getOperand(idx: 1) : innerOp.getOperand(idx: 0));
3135 llvmExpr = moduleTranslation.lookupValue(value: mlirExpr);
3136 } else {
3137 // Since the update region includes more than one operation
3138 // we will resort to generating a cmpxchg loop.
3139 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3140 }
3141 llvmX = moduleTranslation.lookupValue(value: opInst.getX());
3142 llvmXElementType = moduleTranslation.convertType(
3143 type: opInst.getRegion().getArgument(0).getType());
3144 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {.Var: llvmX, .ElemTy: llvmXElementType,
3145 /*isSigned=*/.IsSigned: false,
3146 /*isVolatile=*/.IsVolatile: false};
3147
3148 llvm::AtomicOrdering atomicOrdering =
3149 convertAtomicOrdering(opInst.getMemoryOrder());
3150
3151 // Generate update code.
3152 auto updateFn =
3153 [&opInst, &moduleTranslation](
3154 llvm::Value *atomicx,
3155 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
3156 Block &bb = *opInst.getRegion().begin();
3157 moduleTranslation.mapValue(*opInst.getRegion().args_begin(), atomicx);
3158 moduleTranslation.mapBlock(mlir: &bb, llvm: builder.GetInsertBlock());
3159 if (failed(Result: moduleTranslation.convertBlock(bb, ignoreArguments: true, builder)))
3160 return llvm::make_error<PreviouslyReportedError>();
3161
3162 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3163 assert(yieldop && yieldop.getResults().size() == 1 &&
3164 "terminator must be omp.yield op and it must have exactly one "
3165 "argument");
3166 return moduleTranslation.lookupValue(value: yieldop.getResults()[0]);
3167 };
3168
3169 // Handle ambiguous alloca, if any.
3170 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
3171 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3172 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3173 ompBuilder->createAtomicUpdate(Loc: ompLoc, AllocaIP: allocaIP, X&: llvmAtomicX, Expr: llvmExpr,
3174 AO: atomicOrdering, RMWOp: binop, UpdateOp: updateFn,
3175 IsXBinopExpr: isXBinopExpr);
3176
3177 if (failed(handleError(afterIP, *opInst)))
3178 return failure();
3179
3180 builder.restoreIP(IP: *afterIP);
3181 return success();
3182}
3183
3184static LogicalResult
3185convertOmpAtomicCapture(omp::AtomicCaptureOp atomicCaptureOp,
3186 llvm::IRBuilderBase &builder,
3187 LLVM::ModuleTranslation &moduleTranslation) {
3188 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3189 if (failed(checkImplementationStatus(*atomicCaptureOp)))
3190 return failure();
3191
3192 mlir::Value mlirExpr;
3193 bool isXBinopExpr = false, isPostfixUpdate = false;
3194 llvm::AtomicRMWInst::BinOp binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3195
3196 omp::AtomicUpdateOp atomicUpdateOp = atomicCaptureOp.getAtomicUpdateOp();
3197 omp::AtomicWriteOp atomicWriteOp = atomicCaptureOp.getAtomicWriteOp();
3198
3199 assert((atomicUpdateOp || atomicWriteOp) &&
3200 "internal op must be an atomic.update or atomic.write op");
3201
3202 if (atomicWriteOp) {
3203 isPostfixUpdate = true;
3204 mlirExpr = atomicWriteOp.getExpr();
3205 } else {
3206 isPostfixUpdate = atomicCaptureOp.getSecondOp() ==
3207 atomicCaptureOp.getAtomicUpdateOp().getOperation();
3208 auto &innerOpList = atomicUpdateOp.getRegion().front().getOperations();
3209 // Find the binary update operation that uses the region argument
3210 // and get the expression to update
3211 if (innerOpList.size() == 2) {
3212 mlir::Operation &innerOp = *atomicUpdateOp.getRegion().front().begin();
3213 if (!llvm::is_contained(innerOp.getOperands(),
3214 atomicUpdateOp.getRegion().getArgument(0))) {
3215 return atomicUpdateOp.emitError(
3216 "no atomic update operation with region argument"
3217 " as operand found inside atomic.update region");
3218 }
3219 binop = convertBinOpToAtomic(op&: innerOp);
3220 isXBinopExpr =
3221 innerOp.getOperand(idx: 0) == atomicUpdateOp.getRegion().getArgument(0);
3222 mlirExpr = (isXBinopExpr ? innerOp.getOperand(idx: 1) : innerOp.getOperand(idx: 0));
3223 } else {
3224 binop = llvm::AtomicRMWInst::BinOp::BAD_BINOP;
3225 }
3226 }
3227
3228 llvm::Value *llvmExpr = moduleTranslation.lookupValue(value: mlirExpr);
3229 llvm::Value *llvmX =
3230 moduleTranslation.lookupValue(value: atomicCaptureOp.getAtomicReadOp().getX());
3231 llvm::Value *llvmV =
3232 moduleTranslation.lookupValue(value: atomicCaptureOp.getAtomicReadOp().getV());
3233 llvm::Type *llvmXElementType = moduleTranslation.convertType(
3234 type: atomicCaptureOp.getAtomicReadOp().getElementType());
3235 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicX = {.Var: llvmX, .ElemTy: llvmXElementType,
3236 /*isSigned=*/.IsSigned: false,
3237 /*isVolatile=*/.IsVolatile: false};
3238 llvm::OpenMPIRBuilder::AtomicOpValue llvmAtomicV = {.Var: llvmV, .ElemTy: llvmXElementType,
3239 /*isSigned=*/.IsSigned: false,
3240 /*isVolatile=*/.IsVolatile: false};
3241
3242 llvm::AtomicOrdering atomicOrdering =
3243 convertAtomicOrdering(atomicCaptureOp.getMemoryOrder());
3244
3245 auto updateFn =
3246 [&](llvm::Value *atomicx,
3247 llvm::IRBuilder<> &builder) -> llvm::Expected<llvm::Value *> {
3248 if (atomicWriteOp)
3249 return moduleTranslation.lookupValue(value: atomicWriteOp.getExpr());
3250 Block &bb = *atomicUpdateOp.getRegion().begin();
3251 moduleTranslation.mapValue(*atomicUpdateOp.getRegion().args_begin(),
3252 atomicx);
3253 moduleTranslation.mapBlock(mlir: &bb, llvm: builder.GetInsertBlock());
3254 if (failed(Result: moduleTranslation.convertBlock(bb, ignoreArguments: true, builder)))
3255 return llvm::make_error<PreviouslyReportedError>();
3256
3257 omp::YieldOp yieldop = dyn_cast<omp::YieldOp>(bb.getTerminator());
3258 assert(yieldop && yieldop.getResults().size() == 1 &&
3259 "terminator must be omp.yield op and it must have exactly one "
3260 "argument");
3261 return moduleTranslation.lookupValue(value: yieldop.getResults()[0]);
3262 };
3263
3264 // Handle ambiguous alloca, if any.
3265 auto allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
3266 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3267 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3268 ompBuilder->createAtomicCapture(
3269 Loc: ompLoc, AllocaIP: allocaIP, X&: llvmAtomicX, V&: llvmAtomicV, Expr: llvmExpr, AO: atomicOrdering,
3270 RMWOp: binop, UpdateOp: updateFn, UpdateExpr: atomicUpdateOp, IsPostfixUpdate: isPostfixUpdate, IsXBinopExpr: isXBinopExpr);
3271
3272 if (failed(handleError(afterIP, *atomicCaptureOp)))
3273 return failure();
3274
3275 builder.restoreIP(IP: *afterIP);
3276 return success();
3277}
3278
3279static llvm::omp::Directive convertCancellationConstructType(
3280 omp::ClauseCancellationConstructType directive) {
3281 switch (directive) {
3282 case omp::ClauseCancellationConstructType::Loop:
3283 return llvm::omp::Directive::OMPD_for;
3284 case omp::ClauseCancellationConstructType::Parallel:
3285 return llvm::omp::Directive::OMPD_parallel;
3286 case omp::ClauseCancellationConstructType::Sections:
3287 return llvm::omp::Directive::OMPD_sections;
3288 case omp::ClauseCancellationConstructType::Taskgroup:
3289 return llvm::omp::Directive::OMPD_taskgroup;
3290 }
3291}
3292
3293static LogicalResult
3294convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3295 LLVM::ModuleTranslation &moduleTranslation) {
3296 if (failed(checkImplementationStatus(*op.getOperation())))
3297 return failure();
3298
3299 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3300 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3301
3302 llvm::Value *ifCond = nullptr;
3303 if (Value ifVar = op.getIfExpr())
3304 ifCond = moduleTranslation.lookupValue(value: ifVar);
3305
3306 llvm::omp::Directive cancelledDirective =
3307 convertCancellationConstructType(op.getCancelDirective());
3308
3309 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3310 ompBuilder->createCancel(ompLoc, ifCond, cancelledDirective);
3311
3312 if (failed(handleError(afterIP, *op.getOperation())))
3313 return failure();
3314
3315 builder.restoreIP(IP: afterIP.get());
3316
3317 return success();
3318}
3319
3320static LogicalResult
3321convertOmpCancellationPoint(omp::CancellationPointOp op,
3322 llvm::IRBuilderBase &builder,
3323 LLVM::ModuleTranslation &moduleTranslation) {
3324 if (failed(checkImplementationStatus(*op.getOperation())))
3325 return failure();
3326
3327 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3328 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3329
3330 llvm::omp::Directive cancelledDirective =
3331 convertCancellationConstructType(op.getCancelDirective());
3332
3333 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
3334 ompBuilder->createCancellationPoint(ompLoc, cancelledDirective);
3335
3336 if (failed(handleError(afterIP, *op.getOperation())))
3337 return failure();
3338
3339 builder.restoreIP(IP: afterIP.get());
3340
3341 return success();
3342}
3343
3344/// Converts an OpenMP Threadprivate operation into LLVM IR using
3345/// OpenMPIRBuilder.
3346static LogicalResult
3347convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
3348 LLVM::ModuleTranslation &moduleTranslation) {
3349 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3350 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3351 auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
3352
3353 if (failed(Result: checkImplementationStatus(op&: opInst)))
3354 return failure();
3355
3356 Value symAddr = threadprivateOp.getSymAddr();
3357 auto *symOp = symAddr.getDefiningOp();
3358
3359 if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
3360 symOp = asCast.getOperand().getDefiningOp();
3361
3362 if (!isa<LLVM::AddressOfOp>(symOp))
3363 return opInst.emitError(message: "Addressing symbol not found");
3364 LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
3365
3366 LLVM::GlobalOp global =
3367 addressOfOp.getGlobal(moduleTranslation.symbolTable());
3368 llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(op: global);
3369
3370 if (!ompBuilder->Config.isTargetDevice()) {
3371 llvm::Type *type = globalValue->getValueType();
3372 llvm::TypeSize typeSize =
3373 builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
3374 Ty: type);
3375 llvm::ConstantInt *size = builder.getInt64(C: typeSize.getFixedValue());
3376 llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
3377 Loc: ompLoc, Pointer: globalValue, Size: size, Name: global.getSymName() + ".cache");
3378 moduleTranslation.mapValue(mlir: opInst.getResult(idx: 0), llvm: callInst);
3379 } else {
3380 moduleTranslation.mapValue(mlir: opInst.getResult(idx: 0), llvm: globalValue);
3381 }
3382
3383 return success();
3384}
3385
3386static llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseKind
3387convertToDeviceClauseKind(mlir::omp::DeclareTargetDeviceType deviceClause) {
3388 switch (deviceClause) {
3389 case mlir::omp::DeclareTargetDeviceType::host:
3390 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseHost;
3391 break;
3392 case mlir::omp::DeclareTargetDeviceType::nohost:
3393 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseNoHost;
3394 break;
3395 case mlir::omp::DeclareTargetDeviceType::any:
3396 return llvm::OffloadEntriesInfoManager::OMPTargetDeviceClauseAny;
3397 break;
3398 }
3399 llvm_unreachable("unhandled device clause");
3400}
3401
3402static llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind
3403convertToCaptureClauseKind(
3404 mlir::omp::DeclareTargetCaptureClause captureClause) {
3405 switch (captureClause) {
3406 case mlir::omp::DeclareTargetCaptureClause::to:
3407 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
3408 case mlir::omp::DeclareTargetCaptureClause::link:
3409 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
3410 case mlir::omp::DeclareTargetCaptureClause::enter:
3411 return llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter;
3412 }
3413 llvm_unreachable("unhandled capture clause");
3414}
3415
3416static llvm::SmallString<64>
3417getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp,
3418 llvm::OpenMPIRBuilder &ompBuilder) {
3419 llvm::SmallString<64> suffix;
3420 llvm::raw_svector_ostream os(suffix);
3421 if (globalOp.getVisibility() == mlir::SymbolTable::Visibility::Private) {
3422 auto loc = globalOp->getLoc()->findInstanceOf<FileLineColLoc>();
3423 auto fileInfoCallBack = [&loc]() {
3424 return std::pair<std::string, uint64_t>(
3425 llvm::StringRef(loc.getFilename()), loc.getLine());
3426 };
3427
3428 os << llvm::format(
3429 "_%x", ompBuilder.getTargetEntryUniqueInfo(fileInfoCallBack).FileID);
3430 }
3431 os << "_decl_tgt_ref_ptr";
3432
3433 return suffix;
3434}
3435
3436static bool isDeclareTargetLink(mlir::Value value) {
3437 if (auto addressOfOp =
3438 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3439 auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>();
3440 Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName());
3441 if (auto declareTargetGlobal =
3442 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(gOp))
3443 if (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3444 mlir::omp::DeclareTargetCaptureClause::link)
3445 return true;
3446 }
3447 return false;
3448}
3449
3450// Returns the reference pointer generated by the lowering of the declare target
3451// operation in cases where the link clause is used or the to clause is used in
3452// USM mode.
3453static llvm::Value *
3454getRefPtrIfDeclareTarget(mlir::Value value,
3455 LLVM::ModuleTranslation &moduleTranslation) {
3456 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3457
3458 // An easier way to do this may just be to keep track of any pointer
3459 // references and their mapping to their respective operation
3460 if (auto addressOfOp =
3461 llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) {
3462 if (auto gOp = llvm::dyn_cast_or_null<LLVM::GlobalOp>(
3463 addressOfOp->getParentOfType<mlir::ModuleOp>().lookupSymbol(
3464 addressOfOp.getGlobalName()))) {
3465
3466 if (auto declareTargetGlobal =
3467 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
3468 gOp.getOperation())) {
3469
3470 // In this case, we must utilise the reference pointer generated by the
3471 // declare target operation, similar to Clang
3472 if ((declareTargetGlobal.getDeclareTargetCaptureClause() ==
3473 mlir::omp::DeclareTargetCaptureClause::link) ||
3474 (declareTargetGlobal.getDeclareTargetCaptureClause() ==
3475 mlir::omp::DeclareTargetCaptureClause::to &&
3476 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
3477 llvm::SmallString<64> suffix =
3478 getDeclareTargetRefPtrSuffix(gOp, *ompBuilder);
3479
3480 if (gOp.getSymName().contains(suffix))
3481 return moduleTranslation.getLLVMModule()->getNamedValue(
3482 Name: gOp.getSymName());
3483
3484 return moduleTranslation.getLLVMModule()->getNamedValue(
3485 Name: (gOp.getSymName().str() + suffix.str()).str());
3486 }
3487 }
3488 }
3489 }
3490
3491 return nullptr;
3492}
3493
3494namespace {
3495// Append customMappers information to existing MapInfosTy
3496struct MapInfosTy : llvm::OpenMPIRBuilder::MapInfosTy {
3497 SmallVector<Operation *, 4> Mappers;
3498
3499 /// Append arrays in \a CurInfo.
3500 void append(MapInfosTy &curInfo) {
3501 Mappers.append(in_start: curInfo.Mappers.begin(), in_end: curInfo.Mappers.end());
3502 llvm::OpenMPIRBuilder::MapInfosTy::append(CurInfo&: curInfo);
3503 }
3504};
3505// A small helper structure to contain data gathered
3506// for map lowering and coalese it into one area and
3507// avoiding extra computations such as searches in the
3508// llvm module for lowered mapped variables or checking
3509// if something is declare target (and retrieving the
3510// value) more than neccessary.
3511struct MapInfoData : MapInfosTy {
3512 llvm::SmallVector<bool, 4> IsDeclareTarget;
3513 llvm::SmallVector<bool, 4> IsAMember;
3514 // Identify if mapping was added by mapClause or use_device clauses.
3515 llvm::SmallVector<bool, 4> IsAMapping;
3516 llvm::SmallVector<mlir::Operation *, 4> MapClause;
3517 llvm::SmallVector<llvm::Value *, 4> OriginalValue;
3518 // Stripped off array/pointer to get the underlying
3519 // element type
3520 llvm::SmallVector<llvm::Type *, 4> BaseType;
3521
3522 /// Append arrays in \a CurInfo.
3523 void append(MapInfoData &CurInfo) {
3524 IsDeclareTarget.append(in_start: CurInfo.IsDeclareTarget.begin(),
3525 in_end: CurInfo.IsDeclareTarget.end());
3526 MapClause.append(in_start: CurInfo.MapClause.begin(), in_end: CurInfo.MapClause.end());
3527 OriginalValue.append(in_start: CurInfo.OriginalValue.begin(),
3528 in_end: CurInfo.OriginalValue.end());
3529 BaseType.append(in_start: CurInfo.BaseType.begin(), in_end: CurInfo.BaseType.end());
3530 MapInfosTy::append(curInfo&: CurInfo);
3531 }
3532};
3533} // namespace
3534
3535uint64_t getArrayElementSizeInBits(LLVM::LLVMArrayType arrTy, DataLayout &dl) {
3536 if (auto nestedArrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(
3537 arrTy.getElementType()))
3538 return getArrayElementSizeInBits(nestedArrTy, dl);
3539 return dl.getTypeSizeInBits(t: arrTy.getElementType());
3540}
3541
3542// This function calculates the size to be offloaded for a specified type, given
3543// its associated map clause (which can contain bounds information which affects
3544// the total size), this size is calculated based on the underlying element type
3545// e.g. given a 1-D array of ints, we will calculate the size from the integer
3546// type * number of elements in the array. This size can be used in other
3547// calculations but is ultimately used as an argument to the OpenMP runtimes
3548// kernel argument structure which is generated through the combinedInfo data
3549// structures.
3550// This function is somewhat equivalent to Clang's getExprTypeSize inside of
3551// CGOpenMPRuntime.cpp.
3552llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
3553 Operation *clauseOp, llvm::Value *basePointer,
3554 llvm::Type *baseType, llvm::IRBuilderBase &builder,
3555 LLVM::ModuleTranslation &moduleTranslation) {
3556 if (auto memberClause =
3557 mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
3558 // This calculates the size to transfer based on bounds and the underlying
3559 // element type, provided bounds have been specified (Fortran
3560 // pointers/allocatables/target and arrays that have sections specified fall
3561 // into this as well).
3562 if (!memberClause.getBounds().empty()) {
3563 llvm::Value *elementCount = builder.getInt64(C: 1);
3564 for (auto bounds : memberClause.getBounds()) {
3565 if (auto boundOp = mlir::dyn_cast_if_present<mlir::omp::MapBoundsOp>(
3566 bounds.getDefiningOp())) {
3567 // The below calculation for the size to be mapped calculated from the
3568 // map.info's bounds is: (elemCount * [UB - LB] + 1), later we
3569 // multiply by the underlying element types byte size to get the full
3570 // size to be offloaded based on the bounds
3571 elementCount = builder.CreateMul(
3572 elementCount,
3573 builder.CreateAdd(
3574 builder.CreateSub(
3575 moduleTranslation.lookupValue(boundOp.getUpperBound()),
3576 moduleTranslation.lookupValue(boundOp.getLowerBound())),
3577 builder.getInt64(1)));
3578 }
3579 }
3580
3581 // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
3582 // the size in inconsistent byte or bit format.
3583 uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(t: type);
3584 if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
3585 underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
3586
3587 // The size in bytes x number of elements, the sizeInBytes stored is
3588 // the underyling types size, e.g. if ptr<i32>, it'll be the i32's
3589 // size, so we do some on the fly runtime math to get the size in
3590 // bytes from the extent (ub - lb) * sizeInBytes. NOTE: This may need
3591 // some adjustment for members with more complex types.
3592 return builder.CreateMul(LHS: elementCount,
3593 RHS: builder.getInt64(C: underlyingTypeSzInBits / 8));
3594 }
3595 }
3596
3597 return builder.getInt64(C: dl.getTypeSizeInBits(t: type) / 8);
3598}
3599
3600static void collectMapDataFromMapOperands(
3601 MapInfoData &mapData, SmallVectorImpl<Value> &mapVars,
3602 LLVM::ModuleTranslation &moduleTranslation, DataLayout &dl,
3603 llvm::IRBuilderBase &builder, ArrayRef<Value> useDevPtrOperands = {},
3604 ArrayRef<Value> useDevAddrOperands = {},
3605 ArrayRef<Value> hasDevAddrOperands = {}) {
3606 auto checkIsAMember = [](const auto &mapVars, auto mapOp) {
3607 // Check if this is a member mapping and correctly assign that it is, if
3608 // it is a member of a larger object.
3609 // TODO: Need better handling of members, and distinguishing of members
3610 // that are implicitly allocated on device vs explicitly passed in as
3611 // arguments.
3612 // TODO: May require some further additions to support nested record
3613 // types, i.e. member maps that can have member maps.
3614 for (Value mapValue : mapVars) {
3615 auto map = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3616 for (auto member : map.getMembers())
3617 if (member == mapOp)
3618 return true;
3619 }
3620 return false;
3621 };
3622
3623 // Process MapOperands
3624 for (Value mapValue : mapVars) {
3625 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3626 Value offloadPtr =
3627 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3628 mapData.OriginalValue.push_back(Elt: moduleTranslation.lookupValue(value: offloadPtr));
3629 mapData.Pointers.push_back(Elt: mapData.OriginalValue.back());
3630
3631 if (llvm::Value *refPtr =
3632 getRefPtrIfDeclareTarget(value: offloadPtr,
3633 moduleTranslation)) { // declare target
3634 mapData.IsDeclareTarget.push_back(Elt: true);
3635 mapData.BasePointers.push_back(Elt: refPtr);
3636 } else { // regular mapped variable
3637 mapData.IsDeclareTarget.push_back(Elt: false);
3638 mapData.BasePointers.push_back(Elt: mapData.OriginalValue.back());
3639 }
3640
3641 mapData.BaseType.push_back(
3642 Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
3643 mapData.Sizes.push_back(
3644 Elt: getSizeInBytes(dl, mapOp.getVarType(), mapOp, mapData.Pointers.back(),
3645 mapData.BaseType.back(), builder, moduleTranslation));
3646 mapData.MapClause.push_back(Elt: mapOp.getOperation());
3647 mapData.Types.push_back(
3648 Elt: llvm::omp::OpenMPOffloadMappingFlags(mapOp.getMapType()));
3649 mapData.Names.push_back(Elt: LLVM::createMappingInformation(
3650 loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
3651 mapData.DevicePointers.push_back(Elt: llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3652 if (mapOp.getMapperId())
3653 mapData.Mappers.push_back(
3654 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3655 mapOp, mapOp.getMapperIdAttr()));
3656 else
3657 mapData.Mappers.push_back(Elt: nullptr);
3658 mapData.IsAMapping.push_back(Elt: true);
3659 mapData.IsAMember.push_back(Elt: checkIsAMember(mapVars, mapOp));
3660 }
3661
3662 auto findMapInfo = [&mapData](llvm::Value *val,
3663 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3664 unsigned index = 0;
3665 bool found = false;
3666 for (llvm::Value *basePtr : mapData.OriginalValue) {
3667 if (basePtr == val && mapData.IsAMapping[index]) {
3668 found = true;
3669 mapData.Types[index] |=
3670 llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM;
3671 mapData.DevicePointers[index] = devInfoTy;
3672 }
3673 index++;
3674 }
3675 return found;
3676 };
3677
3678 // Process useDevPtr(Addr)Operands
3679 auto addDevInfos = [&](const llvm::ArrayRef<Value> &useDevOperands,
3680 llvm::OpenMPIRBuilder::DeviceInfoTy devInfoTy) {
3681 for (Value mapValue : useDevOperands) {
3682 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3683 Value offloadPtr =
3684 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3685 llvm::Value *origValue = moduleTranslation.lookupValue(value: offloadPtr);
3686
3687 // Check if map info is already present for this entry.
3688 if (!findMapInfo(origValue, devInfoTy)) {
3689 mapData.OriginalValue.push_back(Elt: origValue);
3690 mapData.Pointers.push_back(Elt: mapData.OriginalValue.back());
3691 mapData.IsDeclareTarget.push_back(Elt: false);
3692 mapData.BasePointers.push_back(Elt: mapData.OriginalValue.back());
3693 mapData.BaseType.push_back(
3694 Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
3695 mapData.Sizes.push_back(Elt: builder.getInt64(C: 0));
3696 mapData.MapClause.push_back(Elt: mapOp.getOperation());
3697 mapData.Types.push_back(
3698 Elt: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_RETURN_PARAM);
3699 mapData.Names.push_back(Elt: LLVM::createMappingInformation(
3700 loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
3701 mapData.DevicePointers.push_back(Elt: devInfoTy);
3702 mapData.Mappers.push_back(Elt: nullptr);
3703 mapData.IsAMapping.push_back(Elt: false);
3704 mapData.IsAMember.push_back(Elt: checkIsAMember(useDevOperands, mapOp));
3705 }
3706 }
3707 };
3708
3709 addDevInfos(useDevAddrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3710 addDevInfos(useDevPtrOperands, llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer);
3711
3712 for (Value mapValue : hasDevAddrOperands) {
3713 auto mapOp = cast<omp::MapInfoOp>(mapValue.getDefiningOp());
3714 Value offloadPtr =
3715 mapOp.getVarPtrPtr() ? mapOp.getVarPtrPtr() : mapOp.getVarPtr();
3716 llvm::Value *origValue = moduleTranslation.lookupValue(value: offloadPtr);
3717 auto mapType =
3718 static_cast<llvm::omp::OpenMPOffloadMappingFlags>(mapOp.getMapType());
3719 auto mapTypeAlways = llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
3720
3721 mapData.OriginalValue.push_back(Elt: origValue);
3722 mapData.BasePointers.push_back(Elt: origValue);
3723 mapData.Pointers.push_back(Elt: origValue);
3724 mapData.IsDeclareTarget.push_back(Elt: false);
3725 mapData.BaseType.push_back(
3726 Elt: moduleTranslation.convertType(type: mapOp.getVarType()));
3727 mapData.Sizes.push_back(
3728 Elt: builder.getInt64(C: dl.getTypeSize(t: mapOp.getVarType())));
3729 mapData.MapClause.push_back(Elt: mapOp.getOperation());
3730 if (llvm::to_underlying(mapType & mapTypeAlways)) {
3731 // Descriptors are mapped with the ALWAYS flag, since they can get
3732 // rematerialized, so the address of the decriptor for a given object
3733 // may change from one place to another.
3734 mapData.Types.push_back(Elt: mapType);
3735 // Technically it's possible for a non-descriptor mapping to have
3736 // both has-device-addr and ALWAYS, so lookup the mapper in case it
3737 // exists.
3738 if (mapOp.getMapperId()) {
3739 mapData.Mappers.push_back(
3740 SymbolTable::lookupNearestSymbolFrom<omp::DeclareMapperOp>(
3741 mapOp, mapOp.getMapperIdAttr()));
3742 } else {
3743 mapData.Mappers.push_back(Elt: nullptr);
3744 }
3745 } else {
3746 mapData.Types.push_back(
3747 Elt: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL);
3748 mapData.Mappers.push_back(Elt: nullptr);
3749 }
3750 mapData.Names.push_back(Elt: LLVM::createMappingInformation(
3751 loc: mapOp.getLoc(), builder&: *moduleTranslation.getOpenMPBuilder()));
3752 mapData.DevicePointers.push_back(
3753 Elt: llvm::OpenMPIRBuilder::DeviceInfoTy::Address);
3754 mapData.IsAMapping.push_back(Elt: false);
3755 mapData.IsAMember.push_back(Elt: checkIsAMember(hasDevAddrOperands, mapOp));
3756 }
3757}
3758
3759static int getMapDataMemberIdx(MapInfoData &mapData, omp::MapInfoOp memberOp) {
3760 auto *res = llvm::find(mapData.MapClause, memberOp);
3761 assert(res != mapData.MapClause.end() &&
3762 "MapInfoOp for member not found in MapData, cannot return index");
3763 return std::distance(mapData.MapClause.begin(), res);
3764}
3765
3766static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
3767 bool first) {
3768 ArrayAttr indexAttr = mapInfo.getMembersIndexAttr();
3769 // Only 1 member has been mapped, we can return it.
3770 if (indexAttr.size() == 1)
3771 return cast<omp::MapInfoOp>(mapInfo.getMembers()[0].getDefiningOp());
3772
3773 llvm::SmallVector<size_t> indices(indexAttr.size());
3774 std::iota(first: indices.begin(), last: indices.end(), value: 0);
3775
3776 llvm::sort(Start: indices.begin(), End: indices.end(),
3777 Comp: [&](const size_t a, const size_t b) {
3778 auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
3779 auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
3780 for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
3781 int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
3782 int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
3783
3784 if (aIndex == bIndex)
3785 continue;
3786
3787 if (aIndex < bIndex)
3788 return first;
3789
3790 if (aIndex > bIndex)
3791 return !first;
3792 }
3793
3794 // Iterated the up until the end of the smallest member and
3795 // they were found to be equal up to that point, so select
3796 // the member with the lowest index count, so the "parent"
3797 return memberIndicesA.size() < memberIndicesB.size();
3798 });
3799
3800 return llvm::cast<omp::MapInfoOp>(
3801 mapInfo.getMembers()[indices.front()].getDefiningOp());
3802}
3803
3804/// This function calculates the array/pointer offset for map data provided
3805/// with bounds operations, e.g. when provided something like the following:
3806///
3807/// Fortran
3808/// map(tofrom: array(2:5, 3:2))
3809/// or
3810/// C++
3811/// map(tofrom: array[1:4][2:3])
3812/// We must calculate the initial pointer offset to pass across, this function
3813/// performs this using bounds.
3814///
3815/// NOTE: which while specified in row-major order it currently needs to be
3816/// flipped for Fortran's column order array allocation and access (as
3817/// opposed to C++'s row-major, hence the backwards processing where order is
3818/// important). This is likely important to keep in mind for the future when
3819/// we incorporate a C++ frontend, both frontends will need to agree on the
3820/// ordering of generated bounds operations (one may have to flip them) to
3821/// make the below lowering frontend agnostic. The offload size
3822/// calcualtion may also have to be adjusted for C++.
3823std::vector<llvm::Value *>
3824calculateBoundsOffset(LLVM::ModuleTranslation &moduleTranslation,
3825 llvm::IRBuilderBase &builder, bool isArrayTy,
3826 OperandRange bounds) {
3827 std::vector<llvm::Value *> idx;
3828 // There's no bounds to calculate an offset from, we can safely
3829 // ignore and return no indices.
3830 if (bounds.empty())
3831 return idx;
3832
3833 // If we have an array type, then we have its type so can treat it as a
3834 // normal GEP instruction where the bounds operations are simply indexes
3835 // into the array. We currently do reverse order of the bounds, which
3836 // I believe leans more towards Fortran's column-major in memory.
3837 if (isArrayTy) {
3838 idx.push_back(x: builder.getInt64(C: 0));
3839 for (int i = bounds.size() - 1; i >= 0; --i) {
3840 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3841 bounds[i].getDefiningOp())) {
3842 idx.push_back(moduleTranslation.lookupValue(value: boundOp.getLowerBound()));
3843 }
3844 }
3845 } else {
3846 // If we do not have an array type, but we have bounds, then we're dealing
3847 // with a pointer that's being treated like an array and we have the
3848 // underlying type e.g. an i32, or f64 etc, e.g. a fortran descriptor base
3849 // address (pointer pointing to the actual data) so we must caclulate the
3850 // offset using a single index which the following two loops attempts to
3851 // compute.
3852
3853 // Calculates the size offset we need to make per row e.g. first row or
3854 // column only needs to be offset by one, but the next would have to be
3855 // the previous row/column offset multiplied by the extent of current row.
3856 //
3857 // For example ([1][10][100]):
3858 //
3859 // - First row/column we move by 1 for each index increment
3860 // - Second row/column we move by 1 (first row/column) * 10 (extent/size of
3861 // current) for 10 for each index increment
3862 // - Third row/column we would move by 10 (second row/column) *
3863 // (extent/size of current) 100 for 1000 for each index increment
3864 std::vector<llvm::Value *> dimensionIndexSizeOffset{builder.getInt64(C: 1)};
3865 for (size_t i = 1; i < bounds.size(); ++i) {
3866 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3867 bounds[i].getDefiningOp())) {
3868 dimensionIndexSizeOffset.push_back(builder.CreateMul(
3869 LHS: moduleTranslation.lookupValue(value: boundOp.getExtent()),
3870 RHS: dimensionIndexSizeOffset[i - 1]));
3871 }
3872 }
3873
3874 // Now that we have calculated how much we move by per index, we must
3875 // multiply each lower bound offset in indexes by the size offset we
3876 // have calculated in the previous and accumulate the results to get
3877 // our final resulting offset.
3878 for (int i = bounds.size() - 1; i >= 0; --i) {
3879 if (auto boundOp = dyn_cast_if_present<omp::MapBoundsOp>(
3880 bounds[i].getDefiningOp())) {
3881 if (idx.empty())
3882 idx.emplace_back(builder.CreateMul(
3883 LHS: moduleTranslation.lookupValue(value: boundOp.getLowerBound()),
3884 RHS: dimensionIndexSizeOffset[i]));
3885 else
3886 idx.back() = builder.CreateAdd(
3887 LHS: idx.back(), RHS: builder.CreateMul(LHS: moduleTranslation.lookupValue(
3888 value: boundOp.getLowerBound()),
3889 RHS: dimensionIndexSizeOffset[i]));
3890 }
3891 }
3892 }
3893
3894 return idx;
3895}
3896
3897// This creates two insertions into the MapInfosTy data structure for the
3898// "parent" of a set of members, (usually a container e.g.
3899// class/structure/derived type) when subsequent members have also been
3900// explicitly mapped on the same map clause. Certain types, such as Fortran
3901// descriptors are mapped like this as well, however, the members are
3902// implicit as far as a user is concerned, but we must explicitly map them
3903// internally.
3904//
3905// This function also returns the memberOfFlag for this particular parent,
3906// which is utilised in subsequent member mappings (by modifying there map type
3907// with it) to indicate that a member is part of this parent and should be
3908// treated by the runtime as such. Important to achieve the correct mapping.
3909//
3910// This function borrows a lot from Clang's emitCombinedEntry function
3911// inside of CGOpenMPRuntime.cpp
3912static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
3913 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
3914 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
3915 MapInfoData &mapData, uint64_t mapDataIndex, bool isTargetParams) {
3916 assert(!ompBuilder.Config.isTargetDevice() &&
3917 "function only supported for host device codegen");
3918
3919 // Map the first segment of our structure
3920 combinedInfo.Types.emplace_back(
3921 Args: isTargetParams
3922 ? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM
3923 : llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE);
3924 combinedInfo.DevicePointers.emplace_back(
3925 Args&: mapData.DevicePointers[mapDataIndex]);
3926 combinedInfo.Mappers.emplace_back(Args&: mapData.Mappers[mapDataIndex]);
3927 combinedInfo.Names.emplace_back(Args: LLVM::createMappingInformation(
3928 loc: mapData.MapClause[mapDataIndex]->getLoc(), builder&: ompBuilder));
3929 combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIndex]);
3930
3931 // Calculate size of the parent object being mapped based on the
3932 // addresses at runtime, highAddr - lowAddr = size. This of course
3933 // doesn't factor in allocated data like pointers, hence the further
3934 // processing of members specified by users, or in the case of
3935 // Fortran pointers and allocatables, the mapping of the pointed to
3936 // data by the descriptor (which itself, is a structure containing
3937 // runtime information on the dynamically allocated data).
3938 auto parentClause =
3939 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3940
3941 llvm::Value *lowAddr, *highAddr;
3942 if (!parentClause.getPartialMap()) {
3943 lowAddr = builder.CreatePointerCast(V: mapData.Pointers[mapDataIndex],
3944 DestTy: builder.getPtrTy());
3945 highAddr = builder.CreatePointerCast(
3946 V: builder.CreateConstGEP1_32(Ty: mapData.BaseType[mapDataIndex],
3947 Ptr: mapData.Pointers[mapDataIndex], Idx0: 1),
3948 DestTy: builder.getPtrTy());
3949 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIndex]);
3950 } else {
3951 auto mapOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
3952 int firstMemberIdx = getMapDataMemberIdx(
3953 mapData, getFirstOrLastMappedMemberPtr(mapOp, true));
3954 lowAddr = builder.CreatePointerCast(V: mapData.Pointers[firstMemberIdx],
3955 DestTy: builder.getPtrTy());
3956 int lastMemberIdx = getMapDataMemberIdx(
3957 mapData, getFirstOrLastMappedMemberPtr(mapOp, false));
3958 highAddr = builder.CreatePointerCast(
3959 V: builder.CreateGEP(Ty: mapData.BaseType[lastMemberIdx],
3960 Ptr: mapData.Pointers[lastMemberIdx], IdxList: builder.getInt64(C: 1)),
3961 DestTy: builder.getPtrTy());
3962 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[firstMemberIdx]);
3963 }
3964
3965 llvm::Value *size = builder.CreateIntCast(
3966 V: builder.CreatePtrDiff(ElemTy: builder.getInt8Ty(), LHS: highAddr, RHS: lowAddr),
3967 DestTy: builder.getInt64Ty(),
3968 /*isSigned=*/false);
3969 combinedInfo.Sizes.push_back(Elt: size);
3970
3971 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag =
3972 ompBuilder.getMemberOfFlag(Position: combinedInfo.BasePointers.size() - 1);
3973
3974 // This creates the initial MEMBER_OF mapping that consists of
3975 // the parent/top level container (same as above effectively, except
3976 // with a fixed initial compile time size and separate maptype which
3977 // indicates the true mape type (tofrom etc.). This parent mapping is
3978 // only relevant if the structure in its totality is being mapped,
3979 // otherwise the above suffices.
3980 if (!parentClause.getPartialMap()) {
3981 // TODO: This will need to be expanded to include the whole host of logic
3982 // for the map flags that Clang currently supports (e.g. it should do some
3983 // further case specific flag modifications). For the moment, it handles
3984 // what we support as expected.
3985 llvm::omp::OpenMPOffloadMappingFlags mapFlag = mapData.Types[mapDataIndex];
3986 ompBuilder.setCorrectMemberOfFlag(Flags&: mapFlag, MemberOfFlag: memberOfFlag);
3987 combinedInfo.Types.emplace_back(Args&: mapFlag);
3988 combinedInfo.DevicePointers.emplace_back(
3989 Args: llvm::OpenMPIRBuilder::DeviceInfoTy::None);
3990 combinedInfo.Mappers.emplace_back(Args: nullptr);
3991 combinedInfo.Names.emplace_back(Args: LLVM::createMappingInformation(
3992 loc: mapData.MapClause[mapDataIndex]->getLoc(), builder&: ompBuilder));
3993 combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIndex]);
3994 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIndex]);
3995 combinedInfo.Sizes.emplace_back(Args&: mapData.Sizes[mapDataIndex]);
3996 }
3997 return memberOfFlag;
3998}
3999
4000// The intent is to verify if the mapped data being passed is a
4001// pointer -> pointee that requires special handling in certain cases,
4002// e.g. applying the OMP_MAP_PTR_AND_OBJ map type.
4003//
4004// There may be a better way to verify this, but unfortunately with
4005// opaque pointers we lose the ability to easily check if something is
4006// a pointer whilst maintaining access to the underlying type.
4007static bool checkIfPointerMap(omp::MapInfoOp mapOp) {
4008 // If we have a varPtrPtr field assigned then the underlying type is a pointer
4009 if (mapOp.getVarPtrPtr())
4010 return true;
4011
4012 // If the map data is declare target with a link clause, then it's represented
4013 // as a pointer when we lower it to LLVM-IR even if at the MLIR level it has
4014 // no relation to pointers.
4015 if (isDeclareTargetLink(mapOp.getVarPtr()))
4016 return true;
4017
4018 return false;
4019}
4020
4021// This function is intended to add explicit mappings of members
4022static void processMapMembersWithParent(
4023 LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder,
4024 llvm::OpenMPIRBuilder &ompBuilder, DataLayout &dl, MapInfosTy &combinedInfo,
4025 MapInfoData &mapData, uint64_t mapDataIndex,
4026 llvm::omp::OpenMPOffloadMappingFlags memberOfFlag) {
4027 assert(!ompBuilder.Config.isTargetDevice() &&
4028 "function only supported for host device codegen");
4029
4030 auto parentClause =
4031 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4032
4033 for (auto mappedMembers : parentClause.getMembers()) {
4034 auto memberClause =
4035 llvm::cast<omp::MapInfoOp>(mappedMembers.getDefiningOp());
4036 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
4037
4038 assert(memberDataIdx >= 0 && "could not find mapped member of structure");
4039
4040 // If we're currently mapping a pointer to a block of data, we must
4041 // initially map the pointer, and then attatch/bind the data with a
4042 // subsequent map to the pointer. This segment of code generates the
4043 // pointer mapping, which can in certain cases be optimised out as Clang
4044 // currently does in its lowering. However, for the moment we do not do so,
4045 // in part as we currently have substantially less information on the data
4046 // being mapped at this stage.
4047 if (checkIfPointerMap(memberClause)) {
4048 auto mapFlag =
4049 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4050 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4051 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4052 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4053 combinedInfo.Types.emplace_back(mapFlag);
4054 combinedInfo.DevicePointers.emplace_back(
4055 llvm::OpenMPIRBuilder::DeviceInfoTy::None);
4056 combinedInfo.Mappers.emplace_back(nullptr);
4057 combinedInfo.Names.emplace_back(
4058 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
4059 combinedInfo.BasePointers.emplace_back(
4060 mapData.BasePointers[mapDataIndex]);
4061 combinedInfo.Pointers.emplace_back(mapData.BasePointers[memberDataIdx]);
4062 combinedInfo.Sizes.emplace_back(builder.getInt64(
4063 moduleTranslation.getLLVMModule()->getDataLayout().getPointerSize()));
4064 }
4065
4066 // Same MemberOfFlag to indicate its link with parent and other members
4067 // of.
4068 auto mapFlag =
4069 llvm::omp::OpenMPOffloadMappingFlags(memberClause.getMapType());
4070 mapFlag &= ~llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4071 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
4072 ompBuilder.setCorrectMemberOfFlag(mapFlag, memberOfFlag);
4073 if (checkIfPointerMap(memberClause))
4074 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4075
4076 combinedInfo.Types.emplace_back(mapFlag);
4077 combinedInfo.DevicePointers.emplace_back(
4078 mapData.DevicePointers[memberDataIdx]);
4079 combinedInfo.Mappers.emplace_back(mapData.Mappers[memberDataIdx]);
4080 combinedInfo.Names.emplace_back(
4081 LLVM::createMappingInformation(memberClause.getLoc(), ompBuilder));
4082 uint64_t basePointerIndex =
4083 checkIfPointerMap(memberClause) ? memberDataIdx : mapDataIndex;
4084 combinedInfo.BasePointers.emplace_back(
4085 mapData.BasePointers[basePointerIndex]);
4086 combinedInfo.Pointers.emplace_back(mapData.Pointers[memberDataIdx]);
4087
4088 llvm::Value *size = mapData.Sizes[memberDataIdx];
4089 if (checkIfPointerMap(memberClause)) {
4090 size = builder.CreateSelect(
4091 builder.CreateIsNull(mapData.Pointers[memberDataIdx]),
4092 builder.getInt64(0), size);
4093 }
4094
4095 combinedInfo.Sizes.emplace_back(size);
4096 }
4097}
4098
4099static void processIndividualMap(MapInfoData &mapData, size_t mapDataIdx,
4100 MapInfosTy &combinedInfo, bool isTargetParams,
4101 int mapDataParentIdx = -1) {
4102 // Declare Target Mappings are excluded from being marked as
4103 // OMP_MAP_TARGET_PARAM as they are not passed as parameters, they're
4104 // marked with OMP_MAP_PTR_AND_OBJ instead.
4105 auto mapFlag = mapData.Types[mapDataIdx];
4106 auto mapInfoOp = llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIdx]);
4107
4108 bool isPtrTy = checkIfPointerMap(mapInfoOp);
4109 if (isPtrTy)
4110 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ;
4111
4112 if (isTargetParams && !mapData.IsDeclareTarget[mapDataIdx])
4113 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM;
4114
4115 if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy &&
4116 !isPtrTy)
4117 mapFlag |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_LITERAL;
4118
4119 // if we're provided a mapDataParentIdx, then the data being mapped is
4120 // part of a larger object (in a parent <-> member mapping) and in this
4121 // case our BasePointer should be the parent.
4122 if (mapDataParentIdx >= 0)
4123 combinedInfo.BasePointers.emplace_back(
4124 Args&: mapData.BasePointers[mapDataParentIdx]);
4125 else
4126 combinedInfo.BasePointers.emplace_back(Args&: mapData.BasePointers[mapDataIdx]);
4127
4128 combinedInfo.Pointers.emplace_back(Args&: mapData.Pointers[mapDataIdx]);
4129 combinedInfo.DevicePointers.emplace_back(Args&: mapData.DevicePointers[mapDataIdx]);
4130 combinedInfo.Mappers.emplace_back(Args&: mapData.Mappers[mapDataIdx]);
4131 combinedInfo.Names.emplace_back(Args&: mapData.Names[mapDataIdx]);
4132 combinedInfo.Types.emplace_back(Args&: mapFlag);
4133 combinedInfo.Sizes.emplace_back(Args&: mapData.Sizes[mapDataIdx]);
4134}
4135
4136static void processMapWithMembersOf(LLVM::ModuleTranslation &moduleTranslation,
4137 llvm::IRBuilderBase &builder,
4138 llvm::OpenMPIRBuilder &ompBuilder,
4139 DataLayout &dl, MapInfosTy &combinedInfo,
4140 MapInfoData &mapData, uint64_t mapDataIndex,
4141 bool isTargetParams) {
4142 assert(!ompBuilder.Config.isTargetDevice() &&
4143 "function only supported for host device codegen");
4144
4145 auto parentClause =
4146 llvm::cast<omp::MapInfoOp>(mapData.MapClause[mapDataIndex]);
4147
4148 // If we have a partial map (no parent referenced in the map clauses of the
4149 // directive, only members) and only a single member, we do not need to bind
4150 // the map of the member to the parent, we can pass the member separately.
4151 if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
4152 auto memberClause = llvm::cast<omp::MapInfoOp>(
4153 parentClause.getMembers()[0].getDefiningOp());
4154 int memberDataIdx = getMapDataMemberIdx(mapData, memberClause);
4155 // Note: Clang treats arrays with explicit bounds that fall into this
4156 // category as a parent with map case, however, it seems this isn't a
4157 // requirement, and processing them as an individual map is fine. So,
4158 // we will handle them as individual maps for the moment, as it's
4159 // difficult for us to check this as we always require bounds to be
4160 // specified currently and it's also marginally more optimal (single
4161 // map rather than two). The difference may come from the fact that
4162 // Clang maps array without bounds as pointers (which we do not
4163 // currently do), whereas we treat them as arrays in all cases
4164 // currently.
4165 processIndividualMap(mapData, mapDataIdx: memberDataIdx, combinedInfo, isTargetParams,
4166 mapDataParentIdx: mapDataIndex);
4167 return;
4168 }
4169
4170 llvm::omp::OpenMPOffloadMappingFlags memberOfParentFlag =
4171 mapParentWithMembers(moduleTranslation, builder, ompBuilder, dl,
4172 combinedInfo, mapData, mapDataIndex, isTargetParams);
4173 processMapMembersWithParent(moduleTranslation, builder, ompBuilder, dl,
4174 combinedInfo, mapData, mapDataIndex,
4175 memberOfFlag: memberOfParentFlag);
4176}
4177
4178// This is a variation on Clang's GenerateOpenMPCapturedVars, which
4179// generates different operation (e.g. load/store) combinations for
4180// arguments to the kernel, based on map capture kinds which are then
4181// utilised in the combinedInfo in place of the original Map value.
4182static void
4183createAlteredByCaptureMap(MapInfoData &mapData,
4184 LLVM::ModuleTranslation &moduleTranslation,
4185 llvm::IRBuilderBase &builder) {
4186 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4187 "function only supported for host device codegen");
4188 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4189 // if it's declare target, skip it, it's handled separately.
4190 if (!mapData.IsDeclareTarget[i]) {
4191 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4192 omp::VariableCaptureKind captureKind = mapOp.getMapCaptureType();
4193 bool isPtrTy = checkIfPointerMap(mapOp);
4194
4195 // Currently handles array sectioning lowerbound case, but more
4196 // logic may be required in the future. Clang invokes EmitLValue,
4197 // which has specialised logic for special Clang types such as user
4198 // defines, so it is possible we will have to extend this for
4199 // structures or other complex types. As the general idea is that this
4200 // function mimics some of the logic from Clang that we require for
4201 // kernel argument passing from host -> device.
4202 switch (captureKind) {
4203 case omp::VariableCaptureKind::ByRef: {
4204 llvm::Value *newV = mapData.Pointers[i];
4205 std::vector<llvm::Value *> offsetIdx = calculateBoundsOffset(
4206 moduleTranslation, builder, mapData.BaseType[i]->isArrayTy(),
4207 mapOp.getBounds());
4208 if (isPtrTy)
4209 newV = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: newV);
4210
4211 if (!offsetIdx.empty())
4212 newV = builder.CreateInBoundsGEP(Ty: mapData.BaseType[i], Ptr: newV, IdxList: offsetIdx,
4213 Name: "array_offset");
4214 mapData.Pointers[i] = newV;
4215 } break;
4216 case omp::VariableCaptureKind::ByCopy: {
4217 llvm::Type *type = mapData.BaseType[i];
4218 llvm::Value *newV;
4219 if (mapData.Pointers[i]->getType()->isPointerTy())
4220 newV = builder.CreateLoad(Ty: type, Ptr: mapData.Pointers[i]);
4221 else
4222 newV = mapData.Pointers[i];
4223
4224 if (!isPtrTy) {
4225 auto curInsert = builder.saveIP();
4226 builder.restoreIP(IP: findAllocaInsertPoint(builder, moduleTranslation));
4227 auto *memTempAlloc =
4228 builder.CreateAlloca(Ty: builder.getPtrTy(), ArraySize: nullptr, Name: ".casted");
4229 builder.restoreIP(IP: curInsert);
4230
4231 builder.CreateStore(Val: newV, Ptr: memTempAlloc);
4232 newV = builder.CreateLoad(Ty: builder.getPtrTy(), Ptr: memTempAlloc);
4233 }
4234
4235 mapData.Pointers[i] = newV;
4236 mapData.BasePointers[i] = newV;
4237 } break;
4238 case omp::VariableCaptureKind::This:
4239 case omp::VariableCaptureKind::VLAType:
4240 mapData.MapClause[i]->emitOpError(message: "Unhandled capture kind");
4241 break;
4242 }
4243 }
4244 }
4245}
4246
4247// Generate all map related information and fill the combinedInfo.
4248static void genMapInfos(llvm::IRBuilderBase &builder,
4249 LLVM::ModuleTranslation &moduleTranslation,
4250 DataLayout &dl, MapInfosTy &combinedInfo,
4251 MapInfoData &mapData, bool isTargetParams = false) {
4252 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4253 "function only supported for host device codegen");
4254
4255 // We wish to modify some of the methods in which arguments are
4256 // passed based on their capture type by the target region, this can
4257 // involve generating new loads and stores, which changes the
4258 // MLIR value to LLVM value mapping, however, we only wish to do this
4259 // locally for the current function/target and also avoid altering
4260 // ModuleTranslation, so we remap the base pointer or pointer stored
4261 // in the map infos corresponding MapInfoData, which is later accessed
4262 // by genMapInfos and createTarget to help generate the kernel and
4263 // kernel arg structure. It primarily becomes relevant in cases like
4264 // bycopy, or byref range'd arrays. In the default case, we simply
4265 // pass thee pointer byref as both basePointer and pointer.
4266 createAlteredByCaptureMap(mapData, moduleTranslation, builder);
4267
4268 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4269
4270 // We operate under the assumption that all vectors that are
4271 // required in MapInfoData are of equal lengths (either filled with
4272 // default constructed data or appropiate information) so we can
4273 // utilise the size from any component of MapInfoData, if we can't
4274 // something is missing from the initial MapInfoData construction.
4275 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4276 // NOTE/TODO: We currently do not support arbitrary depth record
4277 // type mapping.
4278 if (mapData.IsAMember[i])
4279 continue;
4280
4281 auto mapInfoOp = dyn_cast<omp::MapInfoOp>(mapData.MapClause[i]);
4282 if (!mapInfoOp.getMembers().empty()) {
4283 processMapWithMembersOf(moduleTranslation, builder, ompBuilder&: *ompBuilder, dl,
4284 combinedInfo, mapData, mapDataIndex: i, isTargetParams);
4285 continue;
4286 }
4287
4288 processIndividualMap(mapData, mapDataIdx: i, combinedInfo, isTargetParams);
4289 }
4290}
4291
4292static llvm::Expected<llvm::Function *>
4293emitUserDefinedMapper(Operation *declMapperOp, llvm::IRBuilderBase &builder,
4294 LLVM::ModuleTranslation &moduleTranslation,
4295 llvm::StringRef mapperFuncName);
4296
4297static llvm::Expected<llvm::Function *>
4298getOrCreateUserDefinedMapperFunc(Operation *op, llvm::IRBuilderBase &builder,
4299 LLVM::ModuleTranslation &moduleTranslation) {
4300 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4301 "function only supported for host device codegen");
4302 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4303 std::string mapperFuncName =
4304 moduleTranslation.getOpenMPBuilder()->createPlatformSpecificName(
4305 Parts: {"omp_mapper", declMapperOp.getSymName()});
4306
4307 if (auto *lookupFunc = moduleTranslation.lookupFunction(mapperFuncName))
4308 return lookupFunc;
4309
4310 return emitUserDefinedMapper(declMapperOp, builder, moduleTranslation,
4311 mapperFuncName);
4312}
4313
4314static llvm::Expected<llvm::Function *>
4315emitUserDefinedMapper(Operation *op, llvm::IRBuilderBase &builder,
4316 LLVM::ModuleTranslation &moduleTranslation,
4317 llvm::StringRef mapperFuncName) {
4318 assert(!moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4319 "function only supported for host device codegen");
4320 auto declMapperOp = cast<omp::DeclareMapperOp>(op);
4321 auto declMapperInfoOp = declMapperOp.getDeclareMapperInfo();
4322 DataLayout dl = DataLayout(declMapperOp->getParentOfType<ModuleOp>());
4323 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4324 llvm::Type *varType = moduleTranslation.convertType(type: declMapperOp.getType());
4325 SmallVector<Value> mapVars = declMapperInfoOp.getMapVars();
4326
4327 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4328
4329 // Fill up the arrays with all the mapped variables.
4330 MapInfosTy combinedInfo;
4331 auto genMapInfoCB =
4332 [&](InsertPointTy codeGenIP, llvm::Value *ptrPHI,
4333 llvm::Value *unused2) -> llvm::OpenMPIRBuilder::MapInfosOrErrorTy {
4334 builder.restoreIP(IP: codeGenIP);
4335 moduleTranslation.mapValue(declMapperOp.getSymVal(), ptrPHI);
4336 moduleTranslation.mapBlock(mlir: &declMapperOp.getRegion().front(),
4337 llvm: builder.GetInsertBlock());
4338 if (failed(moduleTranslation.convertBlock(bb&: declMapperOp.getRegion().front(),
4339 /*ignoreArguments=*/true,
4340 builder)))
4341 return llvm::make_error<PreviouslyReportedError>();
4342 MapInfoData mapData;
4343 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
4344 builder);
4345 genMapInfos(builder, moduleTranslation, dl, combinedInfo, mapData);
4346
4347 // Drop the mapping that is no longer necessary so that the same region can
4348 // be processed multiple times.
4349 moduleTranslation.forgetMapping(region&: declMapperOp.getRegion());
4350 return combinedInfo;
4351 };
4352
4353 auto customMapperCB = [&](unsigned i) -> llvm::Expected<llvm::Function *> {
4354 if (!combinedInfo.Mappers[i])
4355 return nullptr;
4356 return getOrCreateUserDefinedMapperFunc(op: combinedInfo.Mappers[i], builder,
4357 moduleTranslation);
4358 };
4359
4360 llvm::Expected<llvm::Function *> newFn = ompBuilder->emitUserDefinedMapper(
4361 PrivAndGenMapInfoCB: genMapInfoCB, ElemTy: varType, FuncName: mapperFuncName, CustomMapperCB: customMapperCB);
4362 if (!newFn)
4363 return newFn.takeError();
4364 moduleTranslation.mapFunction(name: mapperFuncName, func: *newFn);
4365 return *newFn;
4366}
4367
4368static LogicalResult
4369convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder,
4370 LLVM::ModuleTranslation &moduleTranslation) {
4371 llvm::Value *ifCond = nullptr;
4372 int64_t deviceID = llvm::omp::OMP_DEVICEID_UNDEF;
4373 SmallVector<Value> mapVars;
4374 SmallVector<Value> useDevicePtrVars;
4375 SmallVector<Value> useDeviceAddrVars;
4376 llvm::omp::RuntimeFunction RTLFn;
4377 DataLayout DL = DataLayout(op->getParentOfType<ModuleOp>());
4378
4379 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4380 llvm::OpenMPIRBuilder::TargetDataInfo info(/*RequiresDevicePointerInfo=*/true,
4381 /*SeparateBeginEndCalls=*/true);
4382
4383 LogicalResult result =
4384 llvm::TypeSwitch<Operation *, LogicalResult>(op)
4385 .Case(caseFn: [&](omp::TargetDataOp dataOp) {
4386 if (failed(checkImplementationStatus(*dataOp)))
4387 return failure();
4388
4389 if (auto ifVar = dataOp.getIfExpr())
4390 ifCond = moduleTranslation.lookupValue(value: ifVar);
4391
4392 if (auto devId = dataOp.getDevice())
4393 if (auto constOp =
4394 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4395 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4396 deviceID = intAttr.getInt();
4397
4398 mapVars = dataOp.getMapVars();
4399 useDevicePtrVars = dataOp.getUseDevicePtrVars();
4400 useDeviceAddrVars = dataOp.getUseDeviceAddrVars();
4401 return success();
4402 })
4403 .Case(caseFn: [&](omp::TargetEnterDataOp enterDataOp) -> LogicalResult {
4404 if (failed(checkImplementationStatus(*enterDataOp)))
4405 return failure();
4406
4407 if (auto ifVar = enterDataOp.getIfExpr())
4408 ifCond = moduleTranslation.lookupValue(value: ifVar);
4409
4410 if (auto devId = enterDataOp.getDevice())
4411 if (auto constOp =
4412 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4413 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4414 deviceID = intAttr.getInt();
4415 RTLFn =
4416 enterDataOp.getNowait()
4417 ? llvm::omp::OMPRTL___tgt_target_data_begin_nowait_mapper
4418 : llvm::omp::OMPRTL___tgt_target_data_begin_mapper;
4419 mapVars = enterDataOp.getMapVars();
4420 info.HasNoWait = enterDataOp.getNowait();
4421 return success();
4422 })
4423 .Case(caseFn: [&](omp::TargetExitDataOp exitDataOp) -> LogicalResult {
4424 if (failed(checkImplementationStatus(*exitDataOp)))
4425 return failure();
4426
4427 if (auto ifVar = exitDataOp.getIfExpr())
4428 ifCond = moduleTranslation.lookupValue(value: ifVar);
4429
4430 if (auto devId = exitDataOp.getDevice())
4431 if (auto constOp =
4432 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4433 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4434 deviceID = intAttr.getInt();
4435
4436 RTLFn = exitDataOp.getNowait()
4437 ? llvm::omp::OMPRTL___tgt_target_data_end_nowait_mapper
4438 : llvm::omp::OMPRTL___tgt_target_data_end_mapper;
4439 mapVars = exitDataOp.getMapVars();
4440 info.HasNoWait = exitDataOp.getNowait();
4441 return success();
4442 })
4443 .Case(caseFn: [&](omp::TargetUpdateOp updateDataOp) -> LogicalResult {
4444 if (failed(checkImplementationStatus(*updateDataOp)))
4445 return failure();
4446
4447 if (auto ifVar = updateDataOp.getIfExpr())
4448 ifCond = moduleTranslation.lookupValue(value: ifVar);
4449
4450 if (auto devId = updateDataOp.getDevice())
4451 if (auto constOp =
4452 dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp()))
4453 if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
4454 deviceID = intAttr.getInt();
4455
4456 RTLFn =
4457 updateDataOp.getNowait()
4458 ? llvm::omp::OMPRTL___tgt_target_data_update_nowait_mapper
4459 : llvm::omp::OMPRTL___tgt_target_data_update_mapper;
4460 mapVars = updateDataOp.getMapVars();
4461 info.HasNoWait = updateDataOp.getNowait();
4462 return success();
4463 })
4464 .Default(defaultFn: [&](Operation *op) {
4465 llvm_unreachable("unexpected operation");
4466 return failure();
4467 });
4468
4469 if (failed(Result: result))
4470 return failure();
4471
4472 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4473 MapInfoData mapData;
4474 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl&: DL,
4475 builder, useDevPtrOperands: useDevicePtrVars, useDevAddrOperands: useDeviceAddrVars);
4476
4477 // Fill up the arrays with all the mapped variables.
4478 MapInfosTy combinedInfo;
4479 auto genMapInfoCB = [&](InsertPointTy codeGenIP) -> MapInfosTy & {
4480 builder.restoreIP(IP: codeGenIP);
4481 genMapInfos(builder, moduleTranslation, dl&: DL, combinedInfo, mapData);
4482 return combinedInfo;
4483 };
4484
4485 // Define a lambda to apply mappings between use_device_addr and
4486 // use_device_ptr base pointers, and their associated block arguments.
4487 auto mapUseDevice =
4488 [&moduleTranslation](
4489 llvm::OpenMPIRBuilder::DeviceInfoTy type,
4490 llvm::ArrayRef<BlockArgument> blockArgs,
4491 llvm::SmallVectorImpl<Value> &useDeviceVars, MapInfoData &mapInfoData,
4492 llvm::function_ref<llvm::Value *(llvm::Value *)> mapper = nullptr) {
4493 for (auto [arg, useDevVar] :
4494 llvm::zip_equal(t&: blockArgs, u&: useDeviceVars)) {
4495
4496 auto getMapBasePtr = [](omp::MapInfoOp mapInfoOp) {
4497 return mapInfoOp.getVarPtrPtr() ? mapInfoOp.getVarPtrPtr()
4498 : mapInfoOp.getVarPtr();
4499 };
4500
4501 auto useDevMap = cast<omp::MapInfoOp>(useDevVar.getDefiningOp());
4502 for (auto [mapClause, devicePointer, basePointer] : llvm::zip_equal(
4503 t&: mapInfoData.MapClause, u&: mapInfoData.DevicePointers,
4504 args&: mapInfoData.BasePointers)) {
4505 auto mapOp = cast<omp::MapInfoOp>(mapClause);
4506 if (getMapBasePtr(mapOp) != getMapBasePtr(useDevMap) ||
4507 devicePointer != type)
4508 continue;
4509
4510 if (llvm::Value *devPtrInfoMap =
4511 mapper ? mapper(basePointer) : basePointer) {
4512 moduleTranslation.mapValue(mlir: arg, llvm: devPtrInfoMap);
4513 break;
4514 }
4515 }
4516 }
4517 };
4518
4519 using BodyGenTy = llvm::OpenMPIRBuilder::BodyGenTy;
4520 auto bodyGenCB = [&](InsertPointTy codeGenIP, BodyGenTy bodyGenType)
4521 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
4522 builder.restoreIP(IP: codeGenIP);
4523 assert(isa<omp::TargetDataOp>(op) &&
4524 "BodyGen requested for non TargetDataOp");
4525 auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(op);
4526 Region &region = cast<omp::TargetDataOp>(op).getRegion();
4527 switch (bodyGenType) {
4528 case BodyGenTy::Priv:
4529 // Check if any device ptr/addr info is available
4530 if (!info.DevicePtrInfoMap.empty()) {
4531 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4532 blockArgIface.getUseDeviceAddrBlockArgs(),
4533 useDeviceAddrVars, mapData,
4534 [&](llvm::Value *basePointer) -> llvm::Value * {
4535 if (!info.DevicePtrInfoMap[basePointer].second)
4536 return nullptr;
4537 return builder.CreateLoad(
4538 Ty: builder.getPtrTy(),
4539 Ptr: info.DevicePtrInfoMap[basePointer].second);
4540 });
4541 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4542 blockArgIface.getUseDevicePtrBlockArgs(), useDevicePtrVars,
4543 mapData, [&](llvm::Value *basePointer) {
4544 return info.DevicePtrInfoMap[basePointer].second;
4545 });
4546
4547 if (failed(Result: inlineConvertOmpRegions(region, blockName: "omp.data.region", builder,
4548 moduleTranslation)))
4549 return llvm::make_error<PreviouslyReportedError>();
4550 }
4551 break;
4552 case BodyGenTy::DupNoPriv:
4553 // We must always restoreIP regardless of doing anything the caller
4554 // does not restore it, leading to incorrect (no) branch generation.
4555 builder.restoreIP(IP: codeGenIP);
4556 break;
4557 case BodyGenTy::NoPriv:
4558 // If device info is available then region has already been generated
4559 if (info.DevicePtrInfoMap.empty()) {
4560 // For device pass, if use_device_ptr(addr) mappings were present,
4561 // we need to link them here before codegen.
4562 if (ompBuilder->Config.IsTargetDevice.value_or(u: false)) {
4563 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Address,
4564 blockArgIface.getUseDeviceAddrBlockArgs(),
4565 useDeviceAddrVars, mapData);
4566 mapUseDevice(llvm::OpenMPIRBuilder::DeviceInfoTy::Pointer,
4567 blockArgIface.getUseDevicePtrBlockArgs(),
4568 useDevicePtrVars, mapData);
4569 }
4570
4571 if (failed(Result: inlineConvertOmpRegions(region, blockName: "omp.data.region", builder,
4572 moduleTranslation)))
4573 return llvm::make_error<PreviouslyReportedError>();
4574 }
4575 break;
4576 }
4577 return builder.saveIP();
4578 };
4579
4580 auto customMapperCB =
4581 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
4582 if (!combinedInfo.Mappers[i])
4583 return nullptr;
4584 info.HasMapper = true;
4585 return getOrCreateUserDefinedMapperFunc(op: combinedInfo.Mappers[i], builder,
4586 moduleTranslation);
4587 };
4588
4589 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4590 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4591 findAllocaInsertPoint(builder, moduleTranslation);
4592 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP = [&]() {
4593 if (isa<omp::TargetDataOp>(op))
4594 return ompBuilder->createTargetData(Loc: ompLoc, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(),
4595 DeviceID: builder.getInt64(C: deviceID), IfCond: ifCond,
4596 Info&: info, GenMapInfoCB: genMapInfoCB, CustomMapperCB: customMapperCB,
4597 /*MapperFunc=*/nullptr, BodyGenCB: bodyGenCB,
4598 /*DeviceAddrCB=*/nullptr);
4599 return ompBuilder->createTargetData(
4600 Loc: ompLoc, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(), DeviceID: builder.getInt64(C: deviceID), IfCond: ifCond,
4601 Info&: info, GenMapInfoCB: genMapInfoCB, CustomMapperCB: customMapperCB, MapperFunc: &RTLFn);
4602 }();
4603
4604 if (failed(Result: handleError(result&: afterIP, op&: *op)))
4605 return failure();
4606
4607 builder.restoreIP(IP: *afterIP);
4608 return success();
4609}
4610
4611static LogicalResult
4612convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
4613 LLVM::ModuleTranslation &moduleTranslation) {
4614 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4615 auto distributeOp = cast<omp::DistributeOp>(opInst);
4616 if (failed(Result: checkImplementationStatus(op&: opInst)))
4617 return failure();
4618
4619 /// Process teams op reduction in distribute if the reduction is contained in
4620 /// the distribute op.
4621 omp::TeamsOp teamsOp = opInst.getParentOfType<omp::TeamsOp>();
4622 bool doDistributeReduction =
4623 teamsOp ? teamsReductionContainedInDistribute(teamsOp) : false;
4624
4625 DenseMap<Value, llvm::Value *> reductionVariableMap;
4626 unsigned numReductionVars = teamsOp ? teamsOp.getNumReductionVars() : 0;
4627 SmallVector<omp::DeclareReductionOp> reductionDecls;
4628 SmallVector<llvm::Value *> privateReductionVariables(numReductionVars);
4629 llvm::ArrayRef<bool> isByRef;
4630
4631 if (doDistributeReduction) {
4632 isByRef = getIsByRef(teamsOp.getReductionByref());
4633 assert(isByRef.size() == teamsOp.getNumReductionVars());
4634
4635 collectReductionDecls(teamsOp, reductionDecls);
4636 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4637 findAllocaInsertPoint(builder, moduleTranslation);
4638
4639 MutableArrayRef<BlockArgument> reductionArgs =
4640 llvm::cast<omp::BlockArgOpenMPOpInterface>(*teamsOp)
4641 .getReductionBlockArgs();
4642
4643 if (failed(allocAndInitializeReductionVars(
4644 teamsOp, reductionArgs, builder, moduleTranslation, allocaIP,
4645 reductionDecls, privateReductionVariables, reductionVariableMap,
4646 isByRef)))
4647 return failure();
4648 }
4649
4650 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
4651 auto bodyGenCB = [&](InsertPointTy allocaIP,
4652 InsertPointTy codeGenIP) -> llvm::Error {
4653 // Save the alloca insertion point on ModuleTranslation stack for use in
4654 // nested regions.
4655 LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
4656 moduleTranslation, allocaIP);
4657
4658 // DistributeOp has only one region associated with it.
4659 builder.restoreIP(IP: codeGenIP);
4660 PrivateVarsInfo privVarsInfo(distributeOp);
4661
4662 llvm::Expected<llvm::BasicBlock *> afterAllocas =
4663 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo&: privVarsInfo, allocaIP);
4664 if (handleError(result&: afterAllocas, op&: opInst).failed())
4665 return llvm::make_error<PreviouslyReportedError>();
4666
4667 if (handleError(error: initPrivateVars(builder, moduleTranslation, privateVarsInfo&: privVarsInfo),
4668 op&: opInst)
4669 .failed())
4670 return llvm::make_error<PreviouslyReportedError>();
4671
4672 if (failed(copyFirstPrivateVars(
4673 distributeOp, builder, moduleTranslation, privVarsInfo.mlirVars,
4674 privVarsInfo.llvmVars, privVarsInfo.privatizers,
4675 distributeOp.getPrivateNeedsBarrier())))
4676 return llvm::make_error<PreviouslyReportedError>();
4677
4678 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4679 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4680 llvm::Expected<llvm::BasicBlock *> regionBlock =
4681 convertOmpOpRegions(distributeOp.getRegion(), "omp.distribute.region",
4682 builder, moduleTranslation);
4683 if (!regionBlock)
4684 return regionBlock.takeError();
4685 builder.SetInsertPoint(TheBB: *regionBlock, IP: (*regionBlock)->begin());
4686
4687 // Skip applying a workshare loop below when translating 'distribute
4688 // parallel do' (it's been already handled by this point while translating
4689 // the nested omp.wsloop).
4690 if (!isa_and_present<omp::WsloopOp>(distributeOp.getNestedWrapper())) {
4691 // TODO: Add support for clauses which are valid for DISTRIBUTE
4692 // constructs. Static schedule is the default.
4693 auto schedule = omp::ClauseScheduleKind::Static;
4694 bool isOrdered = false;
4695 std::optional<omp::ScheduleModifier> scheduleMod;
4696 bool isSimd = false;
4697 llvm::omp::WorksharingLoopType workshareLoopType =
4698 llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4699 bool loopNeedsBarrier = false;
4700 llvm::Value *chunk = nullptr;
4701
4702 llvm::CanonicalLoopInfo *loopInfo =
4703 findCurrentLoopInfo(moduleTranslation);
4704 llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4705 ompBuilder->applyWorkshareLoop(
4706 ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4707 convertToScheduleKind(schedule), chunk, isSimd,
4708 scheduleMod == omp::ScheduleModifier::monotonic,
4709 scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4710 workshareLoopType);
4711
4712 if (!wsloopIP)
4713 return wsloopIP.takeError();
4714 }
4715
4716 if (failed(cleanupPrivateVars(builder, moduleTranslation,
4717 distributeOp.getLoc(), privVarsInfo.llvmVars,
4718 privVarsInfo.privatizers)))
4719 return llvm::make_error<PreviouslyReportedError>();
4720
4721 return llvm::Error::success();
4722 };
4723
4724 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
4725 findAllocaInsertPoint(builder, moduleTranslation);
4726 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
4727 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
4728 ompBuilder->createDistribute(Loc: ompLoc, AllocaIP: allocaIP, BodyGenCB: bodyGenCB);
4729
4730 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
4731 return failure();
4732
4733 builder.restoreIP(IP: *afterIP);
4734
4735 if (doDistributeReduction) {
4736 // Process the reductions if required.
4737 return createReductionsAndCleanup(
4738 teamsOp, builder, moduleTranslation, allocaIP, reductionDecls,
4739 privateReductionVariables, isByRef,
4740 /*isNoWait*/ false, /*isTeamsReduction*/ true);
4741 }
4742 return success();
4743}
4744
4745/// Lowers the FlagsAttr which is applied to the module on the device
4746/// pass when offloading, this attribute contains OpenMP RTL globals that can
4747/// be passed as flags to the frontend, otherwise they are set to default
4748LogicalResult convertFlagsAttr(Operation *op, mlir::omp::FlagsAttr attribute,
4749 LLVM::ModuleTranslation &moduleTranslation) {
4750 if (!cast<mlir::ModuleOp>(op))
4751 return failure();
4752
4753 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
4754
4755 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp-device",
4756 attribute.getOpenmpDeviceVersion());
4757
4758 if (attribute.getNoGpuLib())
4759 return success();
4760
4761 ompBuilder->createGlobalFlag(
4762 Value: attribute.getDebugKind() /*LangOpts().OpenMPTargetDebug*/,
4763 Name: "__omp_rtl_debug_kind");
4764 ompBuilder->createGlobalFlag(
4765 Value: attribute
4766 .getAssumeTeamsOversubscription() /*LangOpts().OpenMPTeamSubscription*/
4767 ,
4768 Name: "__omp_rtl_assume_teams_oversubscription");
4769 ompBuilder->createGlobalFlag(
4770 Value: attribute
4771 .getAssumeThreadsOversubscription() /*LangOpts().OpenMPThreadSubscription*/
4772 ,
4773 Name: "__omp_rtl_assume_threads_oversubscription");
4774 ompBuilder->createGlobalFlag(
4775 Value: attribute.getAssumeNoThreadState() /*LangOpts().OpenMPNoThreadState*/,
4776 Name: "__omp_rtl_assume_no_thread_state");
4777 ompBuilder->createGlobalFlag(
4778 Value: attribute
4779 .getAssumeNoNestedParallelism() /*LangOpts().OpenMPNoNestedParallelism*/
4780 ,
4781 Name: "__omp_rtl_assume_no_nested_parallelism");
4782 return success();
4783}
4784
4785static void getTargetEntryUniqueInfo(llvm::TargetRegionEntryInfo &targetInfo,
4786 omp::TargetOp targetOp,
4787 llvm::StringRef parentName = "") {
4788 auto fileLoc = targetOp.getLoc()->findInstanceOf<FileLineColLoc>();
4789
4790 assert(fileLoc && "No file found from location");
4791 StringRef fileName = fileLoc.getFilename().getValue();
4792
4793 llvm::sys::fs::UniqueID id;
4794 uint64_t line = fileLoc.getLine();
4795 if (auto ec = llvm::sys::fs::getUniqueID(fileName, id)) {
4796 size_t fileHash = llvm::hash_value(arg: fileName.str());
4797 size_t deviceId = 0xdeadf17e;
4798 targetInfo =
4799 llvm::TargetRegionEntryInfo(parentName, deviceId, fileHash, line);
4800 } else {
4801 targetInfo = llvm::TargetRegionEntryInfo(parentName, id.getDevice(),
4802 id.getFile(), line);
4803 }
4804}
4805
4806static void
4807handleDeclareTargetMapVar(MapInfoData &mapData,
4808 LLVM::ModuleTranslation &moduleTranslation,
4809 llvm::IRBuilderBase &builder, llvm::Function *func) {
4810 assert(moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice() &&
4811 "function only supported for target device codegen");
4812 for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
4813 // In the case of declare target mapped variables, the basePointer is
4814 // the reference pointer generated by the convertDeclareTargetAttr
4815 // method. Whereas the kernelValue is the original variable, so for
4816 // the device we must replace all uses of this original global variable
4817 // (stored in kernelValue) with the reference pointer (stored in
4818 // basePointer for declare target mapped variables), as for device the
4819 // data is mapped into this reference pointer and should be loaded
4820 // from it, the original variable is discarded. On host both exist and
4821 // metadata is generated (elsewhere in the convertDeclareTargetAttr)
4822 // function to link the two variables in the runtime and then both the
4823 // reference pointer and the pointer are assigned in the kernel argument
4824 // structure for the host.
4825 if (mapData.IsDeclareTarget[i]) {
4826 // If the original map value is a constant, then we have to make sure all
4827 // of it's uses within the current kernel/function that we are going to
4828 // rewrite are converted to instructions, as we will be altering the old
4829 // use (OriginalValue) from a constant to an instruction, which will be
4830 // illegal and ICE the compiler if the user is a constant expression of
4831 // some kind e.g. a constant GEP.
4832 if (auto *constant = dyn_cast<llvm::Constant>(Val: mapData.OriginalValue[i]))
4833 convertUsersOfConstantsToInstructions(Consts: constant, RestrictToFunc: func, RemoveDeadConstants: false);
4834
4835 // The users iterator will get invalidated if we modify an element,
4836 // so we populate this vector of uses to alter each user on an
4837 // individual basis to emit its own load (rather than one load for
4838 // all).
4839 llvm::SmallVector<llvm::User *> userVec;
4840 for (llvm::User *user : mapData.OriginalValue[i]->users())
4841 userVec.push_back(Elt: user);
4842
4843 for (llvm::User *user : userVec) {
4844 if (auto *insn = dyn_cast<llvm::Instruction>(Val: user)) {
4845 if (insn->getFunction() == func) {
4846 auto *load = builder.CreateLoad(Ty: mapData.BasePointers[i]->getType(),
4847 Ptr: mapData.BasePointers[i]);
4848 load->moveBefore(InsertPos: insn->getIterator());
4849 user->replaceUsesOfWith(From: mapData.OriginalValue[i], To: load);
4850 }
4851 }
4852 }
4853 }
4854 }
4855}
4856
4857// The createDeviceArgumentAccessor function generates
4858// instructions for retrieving (acessing) kernel
4859// arguments inside of the device kernel for use by
4860// the kernel. This enables different semantics such as
4861// the creation of temporary copies of data allowing
4862// semantics like read-only/no host write back kernel
4863// arguments.
4864//
4865// This currently implements a very light version of Clang's
4866// EmitParmDecl's handling of direct argument handling as well
4867// as a portion of the argument access generation based on
4868// capture types found at the end of emitOutlinedFunctionPrologue
4869// in Clang. The indirect path handling of EmitParmDecl's may be
4870// required for future work, but a direct 1-to-1 copy doesn't seem
4871// possible as the logic is rather scattered throughout Clang's
4872// lowering and perhaps we wish to deviate slightly.
4873//
4874// \param mapData - A container containing vectors of information
4875// corresponding to the input argument, which should have a
4876// corresponding entry in the MapInfoData containers
4877// OrigialValue's.
4878// \param arg - This is the generated kernel function argument that
4879// corresponds to the passed in input argument. We generated different
4880// accesses of this Argument, based on capture type and other Input
4881// related information.
4882// \param input - This is the host side value that will be passed to
4883// the kernel i.e. the kernel input, we rewrite all uses of this within
4884// the kernel (as we generate the kernel body based on the target's region
4885// which maintians references to the original input) to the retVal argument
4886// apon exit of this function inside of the OMPIRBuilder. This interlinks
4887// the kernel argument to future uses of it in the function providing
4888// appropriate "glue" instructions inbetween.
4889// \param retVal - This is the value that all uses of input inside of the
4890// kernel will be re-written to, the goal of this function is to generate
4891// an appropriate location for the kernel argument to be accessed from,
4892// e.g. ByRef will result in a temporary allocation location and then
4893// a store of the kernel argument into this allocated memory which
4894// will then be loaded from, ByCopy will use the allocated memory
4895// directly.
4896static llvm::IRBuilderBase::InsertPoint
4897createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
4898 llvm::Value *input, llvm::Value *&retVal,
4899 llvm::IRBuilderBase &builder,
4900 llvm::OpenMPIRBuilder &ompBuilder,
4901 LLVM::ModuleTranslation &moduleTranslation,
4902 llvm::IRBuilderBase::InsertPoint allocaIP,
4903 llvm::IRBuilderBase::InsertPoint codeGenIP) {
4904 assert(ompBuilder.Config.isTargetDevice() &&
4905 "function only supported for target device codegen");
4906 builder.restoreIP(IP: allocaIP);
4907
4908 omp::VariableCaptureKind capture = omp::VariableCaptureKind::ByRef;
4909 LLVM::TypeToLLVMIRTranslator typeToLLVMIRTranslator(
4910 ompBuilder.M.getContext());
4911 unsigned alignmentValue = 0;
4912 // Find the associated MapInfoData entry for the current input
4913 for (size_t i = 0; i < mapData.MapClause.size(); ++i)
4914 if (mapData.OriginalValue[i] == input) {
4915 auto mapOp = cast<omp::MapInfoOp>(mapData.MapClause[i]);
4916 capture = mapOp.getMapCaptureType();
4917 // Get information of alignment of mapped object
4918 alignmentValue = typeToLLVMIRTranslator.getPreferredAlignment(
4919 type: mapOp.getVarType(), layout: ompBuilder.M.getDataLayout());
4920 break;
4921 }
4922
4923 unsigned int allocaAS = ompBuilder.M.getDataLayout().getAllocaAddrSpace();
4924 unsigned int defaultAS =
4925 ompBuilder.M.getDataLayout().getProgramAddressSpace();
4926
4927 // Create the alloca for the argument the current point.
4928 llvm::Value *v = builder.CreateAlloca(Ty: arg.getType(), AddrSpace: allocaAS);
4929
4930 if (allocaAS != defaultAS && arg.getType()->isPointerTy())
4931 v = builder.CreateAddrSpaceCast(V: v, DestTy: builder.getPtrTy(AddrSpace: defaultAS));
4932
4933 builder.CreateStore(Val: &arg, Ptr: v);
4934
4935 builder.restoreIP(IP: codeGenIP);
4936
4937 switch (capture) {
4938 case omp::VariableCaptureKind::ByCopy: {
4939 retVal = v;
4940 break;
4941 }
4942 case omp::VariableCaptureKind::ByRef: {
4943 llvm::LoadInst *loadInst = builder.CreateAlignedLoad(
4944 Ty: v->getType(), Ptr: v,
4945 Align: ompBuilder.M.getDataLayout().getPrefTypeAlign(Ty: v->getType()));
4946 // CreateAlignedLoad function creates similar LLVM IR:
4947 // %res = load ptr, ptr %input, align 8
4948 // This LLVM IR does not contain information about alignment
4949 // of the loaded value. We need to add !align metadata to unblock
4950 // optimizer. The existence of the !align metadata on the instruction
4951 // tells the optimizer that the value loaded is known to be aligned to
4952 // a boundary specified by the integer value in the metadata node.
4953 // Example:
4954 // %res = load ptr, ptr %input, align 8, !align !align_md_node
4955 // ^ ^
4956 // | |
4957 // alignment of %input address |
4958 // |
4959 // alignment of %res object
4960 if (v->getType()->isPointerTy() && alignmentValue) {
4961 llvm::MDBuilder MDB(builder.getContext());
4962 loadInst->setMetadata(
4963 KindID: llvm::LLVMContext::MD_align,
4964 Node: llvm::MDNode::get(Context&: builder.getContext(),
4965 MDs: MDB.createConstant(C: llvm::ConstantInt::get(
4966 Ty: llvm::Type::getInt64Ty(C&: builder.getContext()),
4967 V: alignmentValue))));
4968 }
4969 retVal = loadInst;
4970
4971 break;
4972 }
4973 case omp::VariableCaptureKind::This:
4974 case omp::VariableCaptureKind::VLAType:
4975 // TODO: Consider returning error to use standard reporting for
4976 // unimplemented features.
4977 assert(false && "Currently unsupported capture kind");
4978 break;
4979 }
4980
4981 return builder.saveIP();
4982}
4983
4984/// Follow uses of `host_eval`-defined block arguments of the given `omp.target`
4985/// operation and populate output variables with their corresponding host value
4986/// (i.e. operand evaluated outside of the target region), based on their uses
4987/// inside of the target region.
4988///
4989/// Loop bounds and steps are only optionally populated, if output vectors are
4990/// provided.
4991static void
4992extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
4993 Value &numTeamsLower, Value &numTeamsUpper,
4994 Value &threadLimit,
4995 llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
4996 llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
4997 llvm::SmallVectorImpl<Value> *steps = nullptr) {
4998 auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
4999 for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
5000 blockArgIface.getHostEvalBlockArgs())) {
5001 Value hostEvalVar = std::get<0>(item), blockArg = std::get<1>(item);
5002
5003 for (Operation *user : blockArg.getUsers()) {
5004 llvm::TypeSwitch<Operation *>(user)
5005 .Case([&](omp::TeamsOp teamsOp) {
5006 if (teamsOp.getNumTeamsLower() == blockArg)
5007 numTeamsLower = hostEvalVar;
5008 else if (teamsOp.getNumTeamsUpper() == blockArg)
5009 numTeamsUpper = hostEvalVar;
5010 else if (teamsOp.getThreadLimit() == blockArg)
5011 threadLimit = hostEvalVar;
5012 else
5013 llvm_unreachable("unsupported host_eval use");
5014 })
5015 .Case([&](omp::ParallelOp parallelOp) {
5016 if (parallelOp.getNumThreads() == blockArg)
5017 numThreads = hostEvalVar;
5018 else
5019 llvm_unreachable("unsupported host_eval use");
5020 })
5021 .Case([&](omp::LoopNestOp loopOp) {
5022 auto processBounds =
5023 [&](OperandRange opBounds,
5024 llvm::SmallVectorImpl<Value> *outBounds) -> bool {
5025 bool found = false;
5026 for (auto [i, lb] : llvm::enumerate(opBounds)) {
5027 if (lb == blockArg) {
5028 found = true;
5029 if (outBounds)
5030 (*outBounds)[i] = hostEvalVar;
5031 }
5032 }
5033 return found;
5034 };
5035 bool found =
5036 processBounds(loopOp.getLoopLowerBounds(), lowerBounds);
5037 found = processBounds(loopOp.getLoopUpperBounds(), upperBounds) ||
5038 found;
5039 found = processBounds(loopOp.getLoopSteps(), steps) || found;
5040 (void)found;
5041 assert(found && "unsupported host_eval use");
5042 })
5043 .Default([](Operation *) {
5044 llvm_unreachable("unsupported host_eval use");
5045 });
5046 }
5047 }
5048}
5049
5050/// If \p op is of the given type parameter, return it casted to that type.
5051/// Otherwise, if its immediate parent operation (or some other higher-level
5052/// parent, if \p immediateParent is false) is of that type, return that parent
5053/// casted to the given type.
5054///
5055/// If \p op is \c null or neither it or its parent(s) are of the specified
5056/// type, return a \c null operation.
5057template <typename OpTy>
5058static OpTy castOrGetParentOfType(Operation *op, bool immediateParent = false) {
5059 if (!op)
5060 return OpTy();
5061
5062 if (OpTy casted = dyn_cast<OpTy>(op))
5063 return casted;
5064
5065 if (immediateParent)
5066 return dyn_cast_if_present<OpTy>(op->getParentOp());
5067
5068 return op->getParentOfType<OpTy>();
5069}
5070
5071/// If the given \p value is defined by an \c llvm.mlir.constant operation and
5072/// it is of an integer type, return its value.
5073static std::optional<int64_t> extractConstInteger(Value value) {
5074 if (!value)
5075 return std::nullopt;
5076
5077 if (auto constOp =
5078 dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp()))
5079 if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue()))
5080 return constAttr.getInt();
5081
5082 return std::nullopt;
5083}
5084
5085static uint64_t getTypeByteSize(mlir::Type type, const DataLayout &dl) {
5086 uint64_t sizeInBits = dl.getTypeSizeInBits(t: type);
5087 uint64_t sizeInBytes = sizeInBits / 8;
5088 return sizeInBytes;
5089}
5090
5091template <typename OpTy>
5092static uint64_t getReductionDataSize(OpTy &op) {
5093 if (op.getNumReductionVars() > 0) {
5094 SmallVector<omp::DeclareReductionOp> reductions;
5095 collectReductionDecls(op, reductions);
5096
5097 llvm::SmallVector<mlir::Type> members;
5098 members.reserve(N: reductions.size());
5099 for (omp::DeclareReductionOp &red : reductions)
5100 members.push_back(red.getType());
5101 Operation *opp = op.getOperation();
5102 auto structType = mlir::LLVM::LLVMStructType::getLiteral(
5103 opp->getContext(), members, /*isPacked=*/false);
5104 DataLayout dl = DataLayout(opp->getParentOfType<ModuleOp>());
5105 return getTypeByteSize(structType, dl);
5106 }
5107 return 0;
5108}
5109
5110/// Populate default `MinTeams`, `MaxTeams` and `MaxThreads` to their default
5111/// values as stated by the corresponding clauses, if constant.
5112///
5113/// These default values must be set before the creation of the outlined LLVM
5114/// function for the target region, so that they can be used to initialize the
5115/// corresponding global `ConfigurationEnvironmentTy` structure.
5116static void
5117initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
5118 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &attrs,
5119 bool isTargetDevice, bool isGPU) {
5120 // TODO: Handle constant 'if' clauses.
5121
5122 Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
5123 if (!isTargetDevice) {
5124 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
5125 threadLimit);
5126 } else {
5127 // In the target device, values for these clauses are not passed as
5128 // host_eval, but instead evaluated prior to entry to the region. This
5129 // ensures values are mapped and available inside of the target region.
5130 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5131 numTeamsLower = teamsOp.getNumTeamsLower();
5132 numTeamsUpper = teamsOp.getNumTeamsUpper();
5133 threadLimit = teamsOp.getThreadLimit();
5134 }
5135
5136 if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5137 numThreads = parallelOp.getNumThreads();
5138 }
5139
5140 // Handle clauses impacting the number of teams.
5141
5142 int32_t minTeamsVal = 1, maxTeamsVal = -1;
5143 if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
5144 // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now, match
5145 // clang and set min and max to the same value.
5146 if (numTeamsUpper) {
5147 if (auto val = extractConstInteger(value: numTeamsUpper))
5148 minTeamsVal = maxTeamsVal = *val;
5149 } else {
5150 minTeamsVal = maxTeamsVal = 0;
5151 }
5152 } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
5153 /*immediateParent=*/true) ||
5154 castOrGetParentOfType<omp::SimdOp>(capturedOp,
5155 /*immediateParent=*/true)) {
5156 minTeamsVal = maxTeamsVal = 1;
5157 } else {
5158 minTeamsVal = maxTeamsVal = -1;
5159 }
5160
5161 // Handle clauses impacting the number of threads.
5162
5163 auto setMaxValueFromClause = [](Value clauseValue, int32_t &result) {
5164 if (!clauseValue)
5165 return;
5166
5167 if (auto val = extractConstInteger(value: clauseValue))
5168 result = *val;
5169
5170 // Found an applicable clause, so it's not undefined. Mark as unknown
5171 // because it's not constant.
5172 if (result < 0)
5173 result = 0;
5174 };
5175
5176 // Extract 'thread_limit' clause from 'target' and 'teams' directives.
5177 int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
5178 setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
5179 setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
5180
5181 // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
5182 int32_t maxThreadsVal = -1;
5183 if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
5184 setMaxValueFromClause(numThreads, maxThreadsVal);
5185 else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
5186 /*immediateParent=*/true))
5187 maxThreadsVal = 1;
5188
5189 // For max values, < 0 means unset, == 0 means set but unknown. Select the
5190 // minimum value between 'max_threads' and 'thread_limit' clauses that were
5191 // set.
5192 int32_t combinedMaxThreadsVal = targetThreadLimitVal;
5193 if (combinedMaxThreadsVal < 0 ||
5194 (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
5195 combinedMaxThreadsVal = teamsThreadLimitVal;
5196
5197 if (combinedMaxThreadsVal < 0 ||
5198 (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
5199 combinedMaxThreadsVal = maxThreadsVal;
5200
5201 int32_t reductionDataSize = 0;
5202 if (isGPU && capturedOp) {
5203 if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp))
5204 reductionDataSize = getReductionDataSize(teamsOp);
5205 }
5206
5207 // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
5208 omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5209 assert(
5210 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5211 omp::TargetRegionFlags::spmd) &&
5212 "invalid kernel flags");
5213 attrs.ExecFlags =
5214 omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5215 ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5216 ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5217 : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5218 : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
5219 attrs.MinTeams = minTeamsVal;
5220 attrs.MaxTeams.front() = maxTeamsVal;
5221 attrs.MinThreads = 1;
5222 attrs.MaxThreads.front() = combinedMaxThreadsVal;
5223 attrs.ReductionDataSize = reductionDataSize;
5224 // TODO: Allow modified buffer length similar to
5225 // fopenmp-cuda-teams-reduction-recs-num flag in clang.
5226 if (attrs.ReductionDataSize != 0)
5227 attrs.ReductionBufferLength = 1024;
5228}
5229
5230/// Gather LLVM runtime values for all clauses evaluated in the host that are
5231/// passed to the kernel invocation.
5232///
5233/// This function must be called only when compiling for the host. Also, it will
5234/// only provide correct results if it's called after the body of \c targetOp
5235/// has been fully generated.
5236static void
5237initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
5238 LLVM::ModuleTranslation &moduleTranslation,
5239 omp::TargetOp targetOp, Operation *capturedOp,
5240 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs &attrs) {
5241 omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
5242 unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
5243
5244 Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
5245 llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
5246 steps(numLoops);
5247 extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
5248 teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
5249
5250 // TODO: Handle constant 'if' clauses.
5251 if (Value targetThreadLimit = targetOp.getThreadLimit())
5252 attrs.TargetThreadLimit.front() =
5253 moduleTranslation.lookupValue(value: targetThreadLimit);
5254
5255 if (numTeamsLower)
5256 attrs.MinTeams = moduleTranslation.lookupValue(value: numTeamsLower);
5257
5258 if (numTeamsUpper)
5259 attrs.MaxTeams.front() = moduleTranslation.lookupValue(value: numTeamsUpper);
5260
5261 if (teamsThreadLimit)
5262 attrs.TeamsThreadLimit.front() =
5263 moduleTranslation.lookupValue(value: teamsThreadLimit);
5264
5265 if (numThreads)
5266 attrs.MaxThreads = moduleTranslation.lookupValue(value: numThreads);
5267
5268 if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5269 omp::TargetRegionFlags::trip_count)) {
5270 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5271 attrs.LoopTripCount = nullptr;
5272
5273 // To calculate the trip count, we multiply together the trip counts of
5274 // every collapsed canonical loop. We don't need to create the loop nests
5275 // here, since we're only interested in the trip count.
5276 for (auto [loopLower, loopUpper, loopStep] :
5277 llvm::zip_equal(lowerBounds, upperBounds, steps)) {
5278 llvm::Value *lowerBound = moduleTranslation.lookupValue(loopLower);
5279 llvm::Value *upperBound = moduleTranslation.lookupValue(loopUpper);
5280 llvm::Value *step = moduleTranslation.lookupValue(loopStep);
5281
5282 llvm::OpenMPIRBuilder::LocationDescription loc(builder);
5283 llvm::Value *tripCount = ompBuilder->calculateCanonicalLoopTripCount(
5284 loc, lowerBound, upperBound, step, /*IsSigned=*/true,
5285 loopOp.getLoopInclusive());
5286
5287 if (!attrs.LoopTripCount) {
5288 attrs.LoopTripCount = tripCount;
5289 continue;
5290 }
5291
5292 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
5293 attrs.LoopTripCount = builder.CreateMul(attrs.LoopTripCount, tripCount,
5294 {}, /*HasNUW=*/true);
5295 }
5296 }
5297}
5298
5299static LogicalResult
5300convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
5301 LLVM::ModuleTranslation &moduleTranslation) {
5302 auto targetOp = cast<omp::TargetOp>(opInst);
5303 if (failed(Result: checkImplementationStatus(op&: opInst)))
5304 return failure();
5305
5306 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5307 bool isTargetDevice = ompBuilder->Config.isTargetDevice();
5308 bool isGPU = ompBuilder->Config.isGPU();
5309
5310 auto parentFn = opInst.getParentOfType<LLVM::LLVMFuncOp>();
5311 auto argIface = cast<omp::BlockArgOpenMPOpInterface>(opInst);
5312 auto &targetRegion = targetOp.getRegion();
5313 // Holds the private vars that have been mapped along with the block argument
5314 // that corresponds to the MapInfoOp corresponding to the private var in
5315 // question. So, for instance:
5316 //
5317 // %10 = omp.map.info var_ptr(%6#0 : !fir.ref<!fir.box<!fir.heap<i32>>>, ..)
5318 // omp.target map_entries(%10 -> %arg0) private(@box.privatizer %6#0-> %arg1)
5319 //
5320 // Then, %10 has been created so that the descriptor can be used by the
5321 // privatizer @box.privatizer on the device side. Here we'd record {%6#0,
5322 // %arg0} in the mappedPrivateVars map.
5323 llvm::DenseMap<Value, Value> mappedPrivateVars;
5324 DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
5325 SmallVector<Value> mapVars = targetOp.getMapVars();
5326 SmallVector<Value> hdaVars = targetOp.getHasDeviceAddrVars();
5327 ArrayRef<BlockArgument> mapBlockArgs = argIface.getMapBlockArgs();
5328 ArrayRef<BlockArgument> hdaBlockArgs = argIface.getHasDeviceAddrBlockArgs();
5329 llvm::Function *llvmOutlinedFn = nullptr;
5330
5331 // TODO: It can also be false if a compile-time constant `false` IF clause is
5332 // specified.
5333 bool isOffloadEntry =
5334 isTargetDevice || !ompBuilder->Config.TargetTriples.empty();
5335
5336 // For some private variables, the MapsForPrivatizedVariablesPass
5337 // creates MapInfoOp instances. Go through the private variables and
5338 // the mapped variables so that during codegeneration we are able
5339 // to quickly look up the corresponding map variable, if any for each
5340 // private variable.
5341 if (!targetOp.getPrivateVars().empty() && !targetOp.getMapVars().empty()) {
5342 OperandRange privateVars = targetOp.getPrivateVars();
5343 std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
5344 std::optional<DenseI64ArrayAttr> privateMapIndices =
5345 targetOp.getPrivateMapsAttr();
5346
5347 for (auto [privVarIdx, privVarSymPair] :
5348 llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
5349 auto privVar = std::get<0>(privVarSymPair);
5350 auto privSym = std::get<1>(privVarSymPair);
5351
5352 SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
5353 omp::PrivateClauseOp privatizer =
5354 findPrivatizer(targetOp, privatizerName);
5355
5356 if (!privatizer.needsMap())
5357 continue;
5358
5359 mlir::Value mappedValue =
5360 targetOp.getMappedValueForPrivateVar(privVarIdx);
5361 assert(mappedValue && "Expected to find mapped value for a privatized "
5362 "variable that needs mapping");
5363
5364 // The MapInfoOp defining the map var isn't really needed later.
5365 // So, we don't store it in any datastructure. Instead, we just
5366 // do some sanity checks on it right now.
5367 auto mapInfoOp = mappedValue.getDefiningOp<omp::MapInfoOp>();
5368 [[maybe_unused]] Type varType = mapInfoOp.getVarType();
5369
5370 // Check #1: Check that the type of the private variable matches
5371 // the type of the variable being mapped.
5372 if (!isa<LLVM::LLVMPointerType>(privVar.getType()))
5373 assert(
5374 varType == privVar.getType() &&
5375 "Type of private var doesn't match the type of the mapped value");
5376
5377 // Ok, only 1 sanity check for now.
5378 // Record the block argument corresponding to this mapvar.
5379 mappedPrivateVars.insert(
5380 {privVar,
5381 targetRegion.getArgument(argIface.getMapBlockArgsStart() +
5382 (*privateMapIndices)[privVarIdx])});
5383 }
5384 }
5385
5386 using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
5387 auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP)
5388 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5389 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5390 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5391 // Forward target-cpu and target-features function attributes from the
5392 // original function to the new outlined function.
5393 llvm::Function *llvmParentFn =
5394 moduleTranslation.lookupFunction(name: parentFn.getName());
5395 llvmOutlinedFn = codeGenIP.getBlock()->getParent();
5396 assert(llvmParentFn && llvmOutlinedFn &&
5397 "Both parent and outlined functions must exist at this point");
5398
5399 if (auto attr = llvmParentFn->getFnAttribute(Kind: "target-cpu");
5400 attr.isStringAttribute())
5401 llvmOutlinedFn->addFnAttr(attr);
5402
5403 if (auto attr = llvmParentFn->getFnAttribute(Kind: "target-features");
5404 attr.isStringAttribute())
5405 llvmOutlinedFn->addFnAttr(attr);
5406
5407 for (auto [arg, mapOp] : llvm::zip_equal(mapBlockArgs, mapVars)) {
5408 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5409 llvm::Value *mapOpValue =
5410 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5411 moduleTranslation.mapValue(arg, mapOpValue);
5412 }
5413 for (auto [arg, mapOp] : llvm::zip_equal(hdaBlockArgs, hdaVars)) {
5414 auto mapInfoOp = cast<omp::MapInfoOp>(mapOp.getDefiningOp());
5415 llvm::Value *mapOpValue =
5416 moduleTranslation.lookupValue(mapInfoOp.getVarPtr());
5417 moduleTranslation.mapValue(arg, mapOpValue);
5418 }
5419
5420 // Do privatization after moduleTranslation has already recorded
5421 // mapped values.
5422 PrivateVarsInfo privateVarsInfo(targetOp);
5423
5424 llvm::Expected<llvm::BasicBlock *> afterAllocas =
5425 allocatePrivateVars(builder, moduleTranslation, privateVarsInfo,
5426 allocaIP, mappedPrivateVars: &mappedPrivateVars);
5427
5428 if (failed(handleError(afterAllocas, *targetOp)))
5429 return llvm::make_error<PreviouslyReportedError>();
5430
5431 builder.restoreIP(IP: codeGenIP);
5432 if (handleError(initPrivateVars(builder, moduleTranslation, privateVarsInfo,
5433 mappedPrivateVars: &mappedPrivateVars),
5434 *targetOp)
5435 .failed())
5436 return llvm::make_error<PreviouslyReportedError>();
5437
5438 if (failed(copyFirstPrivateVars(
5439 targetOp, builder, moduleTranslation, privateVarsInfo.mlirVars,
5440 privateVarsInfo.llvmVars, privateVarsInfo.privatizers,
5441 targetOp.getPrivateNeedsBarrier(), &mappedPrivateVars)))
5442 return llvm::make_error<PreviouslyReportedError>();
5443
5444 SmallVector<Region *> privateCleanupRegions;
5445 llvm::transform(privateVarsInfo.privatizers,
5446 std::back_inserter(x&: privateCleanupRegions),
5447 [](omp::PrivateClauseOp privatizer) {
5448 return &privatizer.getDeallocRegion();
5449 });
5450
5451 llvm::Expected<llvm::BasicBlock *> exitBlock = convertOmpOpRegions(
5452 targetRegion, "omp.target", builder, moduleTranslation);
5453
5454 if (!exitBlock)
5455 return exitBlock.takeError();
5456
5457 builder.SetInsertPoint(*exitBlock);
5458 if (!privateCleanupRegions.empty()) {
5459 if (failed(Result: inlineOmpRegionCleanup(
5460 cleanupRegions&: privateCleanupRegions, privateVariables: privateVarsInfo.llvmVars,
5461 moduleTranslation, builder, regionName: "omp.targetop.private.cleanup",
5462 /*shouldLoadCleanupRegionArg=*/false))) {
5463 return llvm::createStringError(
5464 Fmt: "failed to inline `dealloc` region of `omp.private` "
5465 "op in the target region");
5466 }
5467 return builder.saveIP();
5468 }
5469
5470 return InsertPointTy(exitBlock.get(), exitBlock.get()->end());
5471 };
5472
5473 StringRef parentName = parentFn.getName();
5474
5475 llvm::TargetRegionEntryInfo entryInfo;
5476
5477 getTargetEntryUniqueInfo(entryInfo, targetOp, parentName);
5478
5479 MapInfoData mapData;
5480 collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
5481 builder, /*useDevPtrOperands=*/{},
5482 /*useDevAddrOperands=*/{}, hasDevAddrOperands: hdaVars);
5483
5484 MapInfosTy combinedInfos;
5485 auto genMapInfoCB =
5486 [&](llvm::OpenMPIRBuilder::InsertPointTy codeGenIP) -> MapInfosTy & {
5487 builder.restoreIP(IP: codeGenIP);
5488 genMapInfos(builder, moduleTranslation, dl, combinedInfo&: combinedInfos, mapData, isTargetParams: true);
5489 return combinedInfos;
5490 };
5491
5492 auto argAccessorCB = [&](llvm::Argument &arg, llvm::Value *input,
5493 llvm::Value *&retVal, InsertPointTy allocaIP,
5494 InsertPointTy codeGenIP)
5495 -> llvm::OpenMPIRBuilder::InsertPointOrErrorTy {
5496 llvm::IRBuilderBase::InsertPointGuard guard(builder);
5497 builder.SetCurrentDebugLocation(llvm::DebugLoc());
5498 // We just return the unaltered argument for the host function
5499 // for now, some alterations may be required in the future to
5500 // keep host fallback functions working identically to the device
5501 // version (e.g. pass ByCopy values should be treated as such on
5502 // host and device, currently not always the case)
5503 if (!isTargetDevice) {
5504 retVal = cast<llvm::Value>(Val: &arg);
5505 return codeGenIP;
5506 }
5507
5508 return createDeviceArgumentAccessor(mapData, arg, input, retVal, builder,
5509 ompBuilder&: *ompBuilder, moduleTranslation,
5510 allocaIP, codeGenIP);
5511 };
5512
5513 llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
5514 llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs;
5515 Operation *targetCapturedOp = targetOp.getInnermostCapturedOmpOp();
5516 initTargetDefaultAttrs(targetOp, targetCapturedOp, defaultAttrs,
5517 isTargetDevice, isGPU);
5518
5519 // Collect host-evaluated values needed to properly launch the kernel from the
5520 // host.
5521 if (!isTargetDevice)
5522 initTargetRuntimeAttrs(builder, moduleTranslation, targetOp,
5523 targetCapturedOp, runtimeAttrs);
5524
5525 // Pass host-evaluated values as parameters to the kernel / host fallback,
5526 // except if they are constants. In any case, map the MLIR block argument to
5527 // the corresponding LLVM values.
5528 llvm::SmallVector<llvm::Value *, 4> kernelInput;
5529 SmallVector<Value> hostEvalVars = targetOp.getHostEvalVars();
5530 ArrayRef<BlockArgument> hostEvalBlockArgs = argIface.getHostEvalBlockArgs();
5531 for (auto [arg, var] : llvm::zip_equal(hostEvalBlockArgs, hostEvalVars)) {
5532 llvm::Value *value = moduleTranslation.lookupValue(var);
5533 moduleTranslation.mapValue(arg, value);
5534
5535 if (!llvm::isa<llvm::Constant>(value))
5536 kernelInput.push_back(value);
5537 }
5538
5539 for (size_t i = 0, e = mapData.OriginalValue.size(); i != e; ++i) {
5540 // declare target arguments are not passed to kernels as arguments
5541 // TODO: We currently do not handle cases where a member is explicitly
5542 // passed in as an argument, this will likley need to be handled in
5543 // the near future, rather than using IsAMember, it may be better to
5544 // test if the relevant BlockArg is used within the target region and
5545 // then use that as a basis for exclusion in the kernel inputs.
5546 if (!mapData.IsDeclareTarget[i] && !mapData.IsAMember[i])
5547 kernelInput.push_back(Elt: mapData.OriginalValue[i]);
5548 }
5549
5550 SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
5551 buildDependData(targetOp.getDependKinds(), targetOp.getDependVars(),
5552 moduleTranslation, dds);
5553
5554 llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
5555 findAllocaInsertPoint(builder, moduleTranslation);
5556 llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
5557
5558 llvm::OpenMPIRBuilder::TargetDataInfo info(
5559 /*RequiresDevicePointerInfo=*/false,
5560 /*SeparateBeginEndCalls=*/true);
5561
5562 auto customMapperCB =
5563 [&](unsigned int i) -> llvm::Expected<llvm::Function *> {
5564 if (!combinedInfos.Mappers[i])
5565 return nullptr;
5566 info.HasMapper = true;
5567 return getOrCreateUserDefinedMapperFunc(op: combinedInfos.Mappers[i], builder,
5568 moduleTranslation);
5569 };
5570
5571 llvm::Value *ifCond = nullptr;
5572 if (Value targetIfCond = targetOp.getIfExpr())
5573 ifCond = moduleTranslation.lookupValue(value: targetIfCond);
5574
5575 llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
5576 moduleTranslation.getOpenMPBuilder()->createTarget(
5577 Loc: ompLoc, IsOffloadEntry: isOffloadEntry, AllocaIP: allocaIP, CodeGenIP: builder.saveIP(), Info&: info, EntryInfo&: entryInfo,
5578 DefaultAttrs: defaultAttrs, RuntimeAttrs: runtimeAttrs, IfCond: ifCond, Inputs&: kernelInput, GenMapInfoCB: genMapInfoCB, BodyGenCB: bodyCB,
5579 ArgAccessorFuncCB: argAccessorCB, CustomMapperCB: customMapperCB, Dependencies: dds, HasNowait: targetOp.getNowait());
5580
5581 if (failed(Result: handleError(result&: afterIP, op&: opInst)))
5582 return failure();
5583
5584 builder.restoreIP(IP: *afterIP);
5585
5586 // Remap access operations to declare target reference pointers for the
5587 // device, essentially generating extra loadop's as necessary
5588 if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
5589 handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
5590 func: llvmOutlinedFn);
5591
5592 return success();
5593}
5594
5595static LogicalResult
5596convertDeclareTargetAttr(Operation *op, mlir::omp::DeclareTargetAttr attribute,
5597 LLVM::ModuleTranslation &moduleTranslation) {
5598 // Amend omp.declare_target by deleting the IR of the outlined functions
5599 // created for target regions. They cannot be filtered out from MLIR earlier
5600 // because the omp.target operation inside must be translated to LLVM, but
5601 // the wrapper functions themselves must not remain at the end of the
5602 // process. We know that functions where omp.declare_target does not match
5603 // omp.is_target_device at this stage can only be wrapper functions because
5604 // those that aren't are removed earlier as an MLIR transformation pass.
5605 if (FunctionOpInterface funcOp = dyn_cast<FunctionOpInterface>(op)) {
5606 if (auto offloadMod = dyn_cast<omp::OffloadModuleInterface>(
5607 op->getParentOfType<ModuleOp>().getOperation())) {
5608 if (!offloadMod.getIsTargetDevice())
5609 return success();
5610
5611 omp::DeclareTargetDeviceType declareType =
5612 attribute.getDeviceType().getValue();
5613
5614 if (declareType == omp::DeclareTargetDeviceType::host) {
5615 llvm::Function *llvmFunc =
5616 moduleTranslation.lookupFunction(name: funcOp.getName());
5617 llvmFunc->dropAllReferences();
5618 llvmFunc->eraseFromParent();
5619 }
5620 }
5621 return success();
5622 }
5623
5624 if (LLVM::GlobalOp gOp = dyn_cast<LLVM::GlobalOp>(op)) {
5625 llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
5626 if (auto *gVal = llvmModule->getNamedValue(gOp.getSymName())) {
5627 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
5628 bool isDeclaration = gOp.isDeclaration();
5629 bool isExternallyVisible =
5630 gOp.getVisibility() != mlir::SymbolTable::Visibility::Private;
5631 auto loc = op->getLoc()->findInstanceOf<FileLineColLoc>();
5632 llvm::StringRef mangledName = gOp.getSymName();
5633 auto captureClause =
5634 convertToCaptureClauseKind(attribute.getCaptureClause().getValue());
5635 auto deviceClause =
5636 convertToDeviceClauseKind(attribute.getDeviceType().getValue());
5637 // unused for MLIR at the moment, required in Clang for book
5638 // keeping
5639 std::vector<llvm::GlobalVariable *> generatedRefs;
5640
5641 std::vector<llvm::Triple> targetTriple;
5642 auto targetTripleAttr = dyn_cast_or_null<mlir::StringAttr>(
5643 op->getParentOfType<mlir::ModuleOp>()->getAttr(
5644 LLVM::LLVMDialect::getTargetTripleAttrName()));
5645 if (targetTripleAttr)
5646 targetTriple.emplace_back(targetTripleAttr.data());
5647
5648 auto fileInfoCallBack = [&loc]() {
5649 std::string filename = "";
5650 std::uint64_t lineNo = 0;
5651
5652 if (loc) {
5653 filename = loc.getFilename().str();
5654 lineNo = loc.getLine();
5655 }
5656
5657 return std::pair<std::string, std::uint64_t>(llvm::StringRef(filename),
5658 lineNo);
5659 };
5660
5661 ompBuilder->registerTargetGlobalVariable(
5662 CaptureClause: captureClause, DeviceClause: deviceClause, IsDeclaration: isDeclaration, IsExternallyVisible: isExternallyVisible,
5663 EntryInfo: ompBuilder->getTargetEntryUniqueInfo(CallBack: fileInfoCallBack), MangledName: mangledName,
5664 GeneratedRefs&: generatedRefs, /*OpenMPSimd*/ OpenMPSIMD: false, TargetTriple: targetTriple,
5665 /*GlobalInitializer*/ nullptr, /*VariableLinkage*/ nullptr,
5666 LlvmPtrTy: gVal->getType(), Addr: gVal);
5667
5668 if (ompBuilder->Config.isTargetDevice() &&
5669 (attribute.getCaptureClause().getValue() !=
5670 mlir::omp::DeclareTargetCaptureClause::to ||
5671 ompBuilder->Config.hasRequiresUnifiedSharedMemory())) {
5672 ompBuilder->getAddrOfDeclareTargetVar(
5673 CaptureClause: captureClause, DeviceClause: deviceClause, IsDeclaration: isDeclaration, IsExternallyVisible: isExternallyVisible,
5674 EntryInfo: ompBuilder->getTargetEntryUniqueInfo(CallBack: fileInfoCallBack), MangledName: mangledName,
5675 GeneratedRefs&: generatedRefs, /*OpenMPSimd*/ OpenMPSIMD: false, TargetTriple: targetTriple, LlvmPtrTy: gVal->getType(),
5676 /*GlobalInitializer*/ nullptr,
5677 /*VariableLinkage*/ nullptr);
5678 }
5679 }
5680 }
5681
5682 return success();
5683}
5684
5685// Returns true if the operation is inside a TargetOp or
5686// is part of a declare target function.
5687static bool isTargetDeviceOp(Operation *op) {
5688 // Assumes no reverse offloading
5689 if (op->getParentOfType<omp::TargetOp>())
5690 return true;
5691
5692 // Certain operations return results, and whether utilised in host or
5693 // target there is a chance an LLVM Dialect operation depends on it
5694 // by taking it in as an operand, so we must always lower these in
5695 // some manner or result in an ICE (whether they end up in a no-op
5696 // or otherwise).
5697 if (mlir::isa<omp::ThreadprivateOp>(op))
5698 return true;
5699
5700 if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
5701 if (auto declareTargetIface =
5702 llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
5703 parentFn.getOperation()))
5704 if (declareTargetIface.isDeclareTarget() &&
5705 declareTargetIface.getDeclareTargetDeviceType() !=
5706 mlir::omp::DeclareTargetDeviceType::host)
5707 return true;
5708
5709 return false;
5710}
5711
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
///
/// Dispatches on the concrete operation type:
/// - barrier, taskyield and flush are emitted inline through the
///   OpenMPIRBuilder;
/// - other supported ops are forwarded to their dedicated convert* function;
/// - ops handled by their users or parents (terminators, declarations, map,
///   bounds and private clauses) are accepted without emitting anything;
/// - any other op is reported as "not yet implemented".
///
/// \returns failure if the op is unsupported or its conversion failed.
static LogicalResult
convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
                             LLVM::ModuleTranslation &moduleTranslation) {
  llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

  // For each loop, introduce one stack frame to hold loop information. Ensure
  // this is only done for the outermost loop wrapper to prevent introducing
  // multiple stack frames for a single loop. Initially set to null, the loop
  // information structure is initialized during translation of the nested
  // omp.loop_nest operation, making it available to translation of all loop
  // wrappers after their body has been successfully translated.
  bool isOutermostLoopWrapper =
      isa_and_present<omp::LoopWrapperInterface>(op) &&
      !dyn_cast_if_present<omp::LoopWrapperInterface>(op->getParentOp());

  if (isOutermostLoopWrapper)
    moduleTranslation.stackPush<OpenMPLoopInfoStackFrame>();

  auto result =
      llvm::TypeSwitch<Operation *, LogicalResult>(op)
          .Case([&](omp::BarrierOp op) -> LogicalResult {
            if (failed(checkImplementationStatus(*op)))
              return failure();

            llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
                ompBuilder->createBarrier(builder.saveIP(),
                                          llvm::omp::OMPD_barrier);
            LogicalResult res = handleError(afterIP, *op);
            if (res.succeeded()) {
              // If the barrier generated a cancellation check, the insertion
              // point might now need to be changed to a new continuation block
              builder.restoreIP(*afterIP);
            }
            return res;
          })
          .Case([&](omp::TaskyieldOp op) {
            if (failed(checkImplementationStatus(*op)))
              return failure();

            ompBuilder->createTaskyield(builder.saveIP());
            return success();
          })
          .Case([&](omp::FlushOp op) {
            if (failed(checkImplementationStatus(*op)))
              return failure();

            // No support in Openmp runtime function (__kmpc_flush) to accept
            // the argument list.
            // OpenMP standard states the following:
            //  "An implementation may implement a flush with a list by ignoring
            //   the list, and treating it the same as a flush without a list."
            //
            // The argument list is discarded so that, flush with a list is
            // treated same as a flush without a list.
            ompBuilder->createFlush(builder.saveIP());
            return success();
          })
          .Case([&](omp::ParallelOp op) {
            return convertOmpParallel(op, builder, moduleTranslation);
          })
          .Case([&](omp::MaskedOp) {
            return convertOmpMasked(*op, builder, moduleTranslation);
          })
          .Case([&](omp::MasterOp) {
            return convertOmpMaster(*op, builder, moduleTranslation);
          })
          .Case([&](omp::CriticalOp) {
            return convertOmpCritical(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedRegionOp) {
            return convertOmpOrderedRegion(*op, builder, moduleTranslation);
          })
          .Case([&](omp::OrderedOp) {
            return convertOmpOrdered(*op, builder, moduleTranslation);
          })
          .Case([&](omp::WsloopOp) {
            return convertOmpWsloop(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SimdOp) {
            return convertOmpSimd(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicReadOp) {
            return convertOmpAtomicRead(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicWriteOp) {
            return convertOmpAtomicWrite(*op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicUpdateOp op) {
            return convertOmpAtomicUpdate(op, builder, moduleTranslation);
          })
          .Case([&](omp::AtomicCaptureOp op) {
            return convertOmpAtomicCapture(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancelOp op) {
            return convertOmpCancel(op, builder, moduleTranslation);
          })
          .Case([&](omp::CancellationPointOp op) {
            return convertOmpCancellationPoint(op, builder, moduleTranslation);
          })
          .Case([&](omp::SectionsOp) {
            return convertOmpSections(*op, builder, moduleTranslation);
          })
          .Case([&](omp::SingleOp op) {
            return convertOmpSingle(op, builder, moduleTranslation);
          })
          .Case([&](omp::TeamsOp op) {
            return convertOmpTeams(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskOp op) {
            return convertOmpTaskOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskgroupOp op) {
            return convertOmpTaskgroupOp(op, builder, moduleTranslation);
          })
          .Case([&](omp::TaskwaitOp op) {
            return convertOmpTaskwaitOp(op, builder, moduleTranslation);
          })
          .Case<omp::YieldOp, omp::TerminatorOp, omp::DeclareMapperOp,
                omp::DeclareMapperInfoOp, omp::DeclareReductionOp,
                omp::CriticalDeclareOp>([](auto op) {
            // `yield` and `terminator` can be just omitted. The block structure
            // was created in the region that handles their parent operation.
            // `declare_reduction` will be used by reductions and is not
            // converted directly, skip it.
            // `declare_mapper` and `declare_mapper.info` are handled whenever
            // they are referred to through a `map` clause.
            // `critical.declare` is only used to declare names of critical
            // sections which will be used by `critical` ops and hence can be
            // ignored for lowering. The OpenMP IRBuilder will create unique
            // name for critical section names.
            return success();
          })
          .Case([&](omp::ThreadprivateOp) {
            return convertOmpThreadprivate(*op, builder, moduleTranslation);
          })
          .Case<omp::TargetDataOp, omp::TargetEnterDataOp,
                omp::TargetExitDataOp, omp::TargetUpdateOp>([&](auto op) {
            return convertOmpTargetData(op, builder, moduleTranslation);
          })
          .Case([&](omp::TargetOp) {
            return convertOmpTarget(*op, builder, moduleTranslation);
          })
          .Case([&](omp::DistributeOp) {
            return convertOmpDistribute(*op, builder, moduleTranslation);
          })
          .Case([&](omp::LoopNestOp) {
            return convertOmpLoopNest(*op, builder, moduleTranslation);
          })
          .Case<omp::MapInfoOp, omp::MapBoundsOp, omp::PrivateClauseOp>(
              [&](auto op) {
                // No-op, should be handled by relevant owning operations e.g.
                // TargetOp, TargetEnterDataOp, TargetExitDataOp, TargetDataOp
                // etc. and then discarded
                return success();
              })
          // Any OpenMP op not listed above has no lowering yet.
          .Default([&](Operation *inst) {
            return inst->emitError()
                   << "not yet implemented: " << inst->getName();
          });

  // Pop the loop-information frame pushed above; this happens whether or not
  // the conversion succeeded.
  if (isOutermostLoopWrapper)
    moduleTranslation.stackPop();

  return result;
}
5879
5880static LogicalResult
5881convertTargetDeviceOp(Operation *op, llvm::IRBuilderBase &builder,
5882 LLVM::ModuleTranslation &moduleTranslation) {
5883 return convertHostOrTargetOperation(op, builder, moduleTranslation);
5884}
5885
/// Translate only the target-related operations found in \p op, including
/// \p op itself.
///
/// `omp.target` and `omp.target_data` ops are converted for real. Other
/// region-holding OpenMP ops nested inside an LLVM function get a "fake"
/// translation of their regions, without OpenMP runtime calls. This makes the
/// LLVM values they define available to nested target-related ops.
///
/// \returns failure if any of the conversions performed here failed.
static LogicalResult
convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
                       LLVM::ModuleTranslation &moduleTranslation) {
  if (isa<omp::TargetOp>(op))
    return convertOmpTarget(opInst&: *op, builder, moduleTranslation);
  if (isa<omp::TargetDataOp>(op))
    return convertOmpTargetData(op, builder, moduleTranslation);
  // Pre-order walk: a parent is visited before its nested ops. Ops whose
  // regions are handled here return `skip` so their nested ops are not
  // visited again by the walk; any failure interrupts the walk.
  bool interrupted =
      op->walk<WalkOrder::PreOrder>(callback: [&](Operation *oper) {
        if (isa<omp::TargetOp>(oper)) {
          if (failed(Result: convertOmpTarget(opInst&: *oper, builder, moduleTranslation)))
            return WalkResult::interrupt();
          return WalkResult::skip();
        }
        if (isa<omp::TargetDataOp>(oper)) {
          if (failed(Result: convertOmpTargetData(op: oper, builder, moduleTranslation)))
            return WalkResult::interrupt();
          return WalkResult::skip();
        }

        // Non-target ops might nest target-related ops, therefore, we
        // translate them as non-OpenMP scopes. Translating them is needed by
        // nested target-related ops since they might need LLVM values defined
        // in their parent non-target ops.
        if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
            oper->getParentOfType<LLVM::LLVMFuncOp>() &&
            !oper->getRegions().empty()) {
          if (auto blockArgsIface =
                  dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
            forwardArgs(moduleTranslation, blockArgsIface);
          else {
            // Here we map entry block arguments of
            // non-BlockArgOpenMPOpInterface ops if they can be encountered
            // inside of a function and they define any of these arguments.
            // For omp.atomic.update, each region argument is mapped to a
            // load through the corresponding operand's LLVM value.
            if (isa<mlir::omp::AtomicUpdateOp>(oper))
              for (auto [operand, arg] :
                   llvm::zip_equal(t: oper->getOperands(),
                                   u: oper->getRegion(index: 0).getArguments())) {
                moduleTranslation.mapValue(
                    mlir: arg, llvm: builder.CreateLoad(
                             Ty: moduleTranslation.convertType(type: arg.getType()),
                             Ptr: moduleTranslation.lookupValue(value: operand)));
              }
          }

          if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
            assert(builder.GetInsertBlock() &&
                   "No insert block is set for the builder");
            for (auto iv : loopNest.getIVs()) {
              // Map iv to a poison value just to keep the IR valid; no real
              // loop is generated for this fake region.
              moduleTranslation.mapValue(
                  iv, llvm::PoisonValue::get(
                          moduleTranslation.convertType(iv.getType())));
            }
          }

          for (Region &region : oper->getRegions()) {
            // Regions are fake in the sense that they are not a truthful
            // translation of the OpenMP construct being converted (e.g. no
            // OpenMP runtime calls will be generated). We just need this to
            // prepare the kernel invocation args.
            SmallVector<llvm::PHINode *> phis;
            auto result = convertOmpOpRegions(
                region, blockName: oper->getName().getStringRef().str() + ".fake.region",
                builder, moduleTranslation, continuationBlockPHIs: &phis);
            if (failed(Result: handleError(result, op&: *oper)))
              return WalkResult::interrupt();

            // Continue emitting after the region's continuation block.
            builder.SetInsertPoint(TheBB: result.get(), IP: result.get()->end());
          }

          return WalkResult::skip();
        }

        return WalkResult::advance();
      }).wasInterrupted();
  return failure(IsFailure: interrupted);
}
5964
5965namespace {
5966
5967/// Implementation of the dialect interface that converts operations belonging
5968/// to the OpenMP dialect to LLVM IR.
5969class OpenMPDialectLLVMIRTranslationInterface
5970 : public LLVMTranslationDialectInterface {
5971public:
5972 using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
5973
5974 /// Translates the given operation to LLVM IR using the provided IR builder
5975 /// and saving the state in `moduleTranslation`.
5976 LogicalResult
5977 convertOperation(Operation *op, llvm::IRBuilderBase &builder,
5978 LLVM::ModuleTranslation &moduleTranslation) const final;
5979
5980 /// Given an OpenMP MLIR attribute, create the corresponding LLVM-IR,
5981 /// runtime calls, or operation amendments
5982 LogicalResult
5983 amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
5984 NamedAttribute attribute,
5985 LLVM::ModuleTranslation &moduleTranslation) const final;
5986};
5987
5988} // namespace
5989
5990LogicalResult OpenMPDialectLLVMIRTranslationInterface::amendOperation(
5991 Operation *op, ArrayRef<llvm::Instruction *> instructions,
5992 NamedAttribute attribute,
5993 LLVM::ModuleTranslation &moduleTranslation) const {
5994 return llvm::StringSwitch<llvm::function_ref<LogicalResult(Attribute)>>(
5995 attribute.getName())
5996 .Case(S: "omp.is_target_device",
5997 Value: [&](Attribute attr) {
5998 if (auto deviceAttr = dyn_cast<BoolAttr>(Val&: attr)) {
5999 llvm::OpenMPIRBuilderConfig &config =
6000 moduleTranslation.getOpenMPBuilder()->Config;
6001 config.setIsTargetDevice(deviceAttr.getValue());
6002 return success();
6003 }
6004 return failure();
6005 })
6006 .Case(S: "omp.is_gpu",
6007 Value: [&](Attribute attr) {
6008 if (auto gpuAttr = dyn_cast<BoolAttr>(Val&: attr)) {
6009 llvm::OpenMPIRBuilderConfig &config =
6010 moduleTranslation.getOpenMPBuilder()->Config;
6011 config.setIsGPU(gpuAttr.getValue());
6012 return success();
6013 }
6014 return failure();
6015 })
6016 .Case(S: "omp.host_ir_filepath",
6017 Value: [&](Attribute attr) {
6018 if (auto filepathAttr = dyn_cast<StringAttr>(attr)) {
6019 llvm::OpenMPIRBuilder *ompBuilder =
6020 moduleTranslation.getOpenMPBuilder();
6021 ompBuilder->loadOffloadInfoMetadata(filepathAttr.getValue());
6022 return success();
6023 }
6024 return failure();
6025 })
6026 .Case(S: "omp.flags",
6027 Value: [&](Attribute attr) {
6028 if (auto rtlAttr = dyn_cast<omp::FlagsAttr>(attr))
6029 return convertFlagsAttr(op, rtlAttr, moduleTranslation);
6030 return failure();
6031 })
6032 .Case(S: "omp.version",
6033 Value: [&](Attribute attr) {
6034 if (auto versionAttr = dyn_cast<omp::VersionAttr>(attr)) {
6035 llvm::OpenMPIRBuilder *ompBuilder =
6036 moduleTranslation.getOpenMPBuilder();
6037 ompBuilder->M.addModuleFlag(llvm::Module::Max, "openmp",
6038 versionAttr.getVersion());
6039 return success();
6040 }
6041 return failure();
6042 })
6043 .Case(S: "omp.declare_target",
6044 Value: [&](Attribute attr) {
6045 if (auto declareTargetAttr =
6046 dyn_cast<omp::DeclareTargetAttr>(attr))
6047 return convertDeclareTargetAttr(op, declareTargetAttr,
6048 moduleTranslation);
6049 return failure();
6050 })
6051 .Case(S: "omp.requires",
6052 Value: [&](Attribute attr) {
6053 if (auto requiresAttr = dyn_cast<omp::ClauseRequiresAttr>(attr)) {
6054 using Requires = omp::ClauseRequires;
6055 Requires flags = requiresAttr.getValue();
6056 llvm::OpenMPIRBuilderConfig &config =
6057 moduleTranslation.getOpenMPBuilder()->Config;
6058 config.setHasRequiresReverseOffload(
6059 bitEnumContainsAll(flags, Requires::reverse_offload));
6060 config.setHasRequiresUnifiedAddress(
6061 bitEnumContainsAll(flags, Requires::unified_address));
6062 config.setHasRequiresUnifiedSharedMemory(
6063 bitEnumContainsAll(flags, Requires::unified_shared_memory));
6064 config.setHasRequiresDynamicAllocators(
6065 bitEnumContainsAll(flags, Requires::dynamic_allocators));
6066 return success();
6067 }
6068 return failure();
6069 })
6070 .Case(S: "omp.target_triples",
6071 Value: [&](Attribute attr) {
6072 if (auto triplesAttr = dyn_cast<ArrayAttr>(attr)) {
6073 llvm::OpenMPIRBuilderConfig &config =
6074 moduleTranslation.getOpenMPBuilder()->Config;
6075 config.TargetTriples.clear();
6076 config.TargetTriples.reserve(N: triplesAttr.size());
6077 for (Attribute tripleAttr : triplesAttr) {
6078 if (auto tripleStrAttr = dyn_cast<StringAttr>(tripleAttr))
6079 config.TargetTriples.emplace_back(tripleStrAttr.getValue());
6080 else
6081 return failure();
6082 }
6083 return success();
6084 }
6085 return failure();
6086 })
6087 .Default(Value: [](Attribute) {
6088 // Fall through for omp attributes that do not require lowering.
6089 return success();
6090 })(attribute.getValue());
6091
6092 return failure();
6093}
6094
6095/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
6096/// (including OpenMP runtime calls).
6097LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation(
6098 Operation *op, llvm::IRBuilderBase &builder,
6099 LLVM::ModuleTranslation &moduleTranslation) const {
6100
6101 llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
6102 if (ompBuilder->Config.isTargetDevice()) {
6103 if (isTargetDeviceOp(op)) {
6104 return convertTargetDeviceOp(op, builder, moduleTranslation);
6105 } else {
6106 return convertTargetOpsInNest(op, builder, moduleTranslation);
6107 }
6108 }
6109 return convertHostOrTargetOperation(op, builder, moduleTranslation);
6110}
6111
6112void mlir::registerOpenMPDialectTranslation(DialectRegistry &registry) {
6113 registry.insert<omp::OpenMPDialect>();
6114 registry.addExtension(extensionFn: +[](MLIRContext *ctx, omp::OpenMPDialect *dialect) {
6115 dialect->addInterfaces<OpenMPDialectLLVMIRTranslationInterface>();
6116 });
6117}
6118
6119void mlir::registerOpenMPDialectTranslation(MLIRContext &context) {
6120 DialectRegistry registry;
6121 registerOpenMPDialectTranslation(registry);
6122 context.appendDialectRegistry(registry);
6123}
6124

source code of mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp