//===----- Workshare.cpp - OpenMP workshare implementation ------ C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the KMPC interface
// for the loop construct plus other worksharing constructs that use the same
// interface as loops.
//
//===----------------------------------------------------------------------===//

#include "Debug.h"
#include "Interface.h"
#include "Mapping.h"
#include "State.h"
#include "Synchronization.h"
#include "Types.h"
#include "Utils.h"

using namespace ompx;

// TODO:
struct DynamicScheduleTracker {
  int64_t Chunk;
  int64_t LoopUpperBound;
  int64_t NextLowerBound;
  int64_t Stride;
  kmp_sched_t ScheduleType;
  DynamicScheduleTracker *NextDST;
};

#define ASSERT0(...)

// used by the library for the interface with the app
#define DISPATCH_FINISHED 0
#define DISPATCH_NOTFINISHED 1

// used by dynamic scheduling
#define FINISHED 0
#define NOT_FINISHED 1
#define LAST_CHUNK 2

#pragma omp begin declare target device_type(nohost)

// TODO: This variable is a hack inherited from the old runtime.
static uint64_t SHARED(Cnt);

template <typename T, typename ST> struct omptarget_nvptx_LoopSupport {
  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling with chunk

  // Generic implementation of OMP loop scheduling with static policy
  /*! \brief Calculate initial bounds for static loop and stride
   * @param[in] loc location in code of the call (not used here)
   * @param[in] global_tid global thread id
   * @param[in] schedtype type of scheduling (see omptarget-nvptx.h)
   * @param[in] plastiter pointer to last iteration
   * @param[in,out] plower pointer to loop lower bound; it will contain the
   * lower bound of the first chunk
   * @param[in,out] pupper pointer to loop upper bound; it will contain the
   * upper bound of the first chunk
   * @param[in,out] pstride pointer to loop stride; it will contain the stride
   * between two successive chunks executed by the same thread
   * @param[in] incr loop increment bump
   * @param[in] chunk chunk size
   */

  // helper function for static chunk
  static void ForStaticChunk(int &last, T &lb, T &ub, ST &stride, ST chunk,
                             T entityId, T numberOfEntities) {
    // Each thread executes multiple chunks, all of the same size, except
    // the last one.
    // Distance between two successive chunks:
    stride = numberOfEntities * chunk;
    lb = lb + entityId * chunk;
    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    // Say ub' is the beginning of the last chunk. Then whoever has a
    // lower bound plus a multiple of the increment equal to ub' is
    // the last one.
    T beginningLastChunk = inputUb - (inputUb % chunk);
    last = ((beginningLastChunk - lb) % stride) == 0;
  }
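
  // Worked example (illustrative): with lb=0, ub=15, chunk=2 and
  // numberOfEntities=4, the stride is 4*2=8 and thread t receives the first
  // chunk [2*t, 2*t+1]. The last chunk starts at
  // beginningLastChunk = 15 - (15 % 2) = 14; only thread 3 (lb=6) satisfies
  // (14 - lb) % 8 == 0, so thread 3 reports `last` for chunk [14,15].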

  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling without chunk

  // helper function for static no chunk
  static void ForStaticNoChunk(int &last, T &lb, T &ub, ST &stride, ST &chunk,
                               T entityId, T numberOfEntities) {
    // No chunk size specified. Each thread or warp gets at most one
    // chunk; all chunks are of almost equal size.
    T loopSize = ub - lb + 1;

    chunk = loopSize / numberOfEntities;
    T leftOver = loopSize - chunk * numberOfEntities;

    if (entityId < leftOver) {
      chunk++;
      lb = lb + entityId * chunk;
    } else {
      lb = lb + entityId * chunk + leftOver;
    }

    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    last = lb <= inputUb && inputUb <= ub;
    stride = loopSize; // make sure we only do 1 chunk per warp
  }
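
  // Worked example (illustrative): with lb=0, ub=9 (10 iterations) and
  // numberOfEntities=4, chunk = 10/4 = 2 and leftOver = 2, so entities 0 and 1
  // get 3 iterations ([0,2] and [3,5]) while entities 2 and 3 get 2 ([6,7] and
  // [8,9]); entity 3 is the one flagged as executing the last iteration.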

  ////////////////////////////////////////////////////////////////////////////////
  // Support for Static Init

  static void for_static_init(int32_t, int32_t schedtype, int32_t *plastiter,
                              T *plower, T *pupper, ST *pstride, ST chunk,
                              bool IsSPMDExecutionMode) {
    int32_t gtid = omp_get_thread_num();
    int numberOfActiveOMPThreads = omp_get_num_threads();

    // All warps that are in excess of the maximum requested do
    // not execute the loop
    ASSERT0(LT_FUSSY, gtid < numberOfActiveOMPThreads,
            "current thread is not needed here; error");

    // copy
    int lastiter = 0;
    T lb = *plower;
    T ub = *pupper;
    ST stride = *pstride;

    // init
    switch (SCHEDULE_WITHOUT_MODIFIERS(schedtype)) {
    case kmp_sched_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_balanced_chunk: {
      if (chunk > 0) {
        // round up to make sure the chunk is enough to cover all iterations
        T tripCount = ub - lb + 1; // +1 because ub is inclusive
        T span = (tripCount + numberOfActiveOMPThreads - 1) /
                 numberOfActiveOMPThreads;
        // perform chunk adjustment
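        // Note: the masking below rounds `span` up to a multiple of `chunk`;
        // it assumes `chunk` is a power of two.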
        chunk = (span + chunk - 1) & ~(chunk - 1);

        ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
        T oldUb = ub;
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        if (ub > oldUb)
          ub = oldUb;
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
      break;
    }
    case kmp_sched_distr_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_distr_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
      break;
    }
    case kmp_sched_distr_static_chunk_sched_static_chunkone: {
      ForStaticChunk(lastiter, lb, ub, stride, chunk,
                     numberOfActiveOMPThreads * omp_get_team_num() + gtid,
                     omp_get_num_teams() * numberOfActiveOMPThreads);
      break;
    }
    default: {
      // ASSERT(LT_FUSSY, 0, "unknown schedtype %d", (int)schedtype);
      ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                     numberOfActiveOMPThreads);
      break;
    }
    }
    // copy back
    *plastiter = lastiter;
    *plower = lb;
    *pupper = ub;
    *pstride = stride;
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch Init

  static int OrderedSchedule(kmp_sched_t schedule) {
    return schedule >= kmp_sched_ordered_first &&
           schedule <= kmp_sched_ordered_last;
  }

  static void dispatch_init(IdentTy *loc, int32_t threadId,
                            kmp_sched_t schedule, T lb, T ub, ST st, ST chunk,
                            DynamicScheduleTracker *DST) {
    int tid = mapping::getThreadIdInBlock();
    T tnum = omp_get_num_threads();
    T tripCount = ub - lb + 1; // +1 because ub is inclusive
    ASSERT0(LT_FUSSY, threadId < tnum,
            "current thread is not needed here; error");

    /* Currently just ignore the monotonic and non-monotonic modifiers
     * (the compiler isn't producing them yet anyway).
     * When it is we'll want to look at them somewhere here and use that
     * information to add to our schedule choice. We shouldn't need to pass
     * them on, they merely affect which schedule we can legally choose for
     * various dynamic cases. (In particular, whether or not a stealing scheme
     * is legal).
     */
    schedule = SCHEDULE_WITHOUT_MODIFIERS(schedule);

    // Process schedule.
    if (tnum == 1 || tripCount <= 1 || OrderedSchedule(schedule)) {
      if (OrderedSchedule(schedule))
        __kmpc_barrier(loc, threadId);
      schedule = kmp_sched_static_chunk;
      chunk = tripCount; // one thread gets the whole loop
    } else if (schedule == kmp_sched_runtime) {
      // process runtime
      omp_sched_t rtSched;
      int ChunkInt;
      omp_get_schedule(&rtSched, &ChunkInt);
      chunk = ChunkInt;
      switch (rtSched) {
      case omp_sched_static: {
        if (chunk > 0)
          schedule = kmp_sched_static_chunk;
        else
          schedule = kmp_sched_static_nochunk;
        break;
      }
      case omp_sched_auto: {
        schedule = kmp_sched_static_chunk;
        chunk = 1;
        break;
      }
      case omp_sched_dynamic:
      case omp_sched_guided: {
        schedule = kmp_sched_dynamic;
        break;
      }
      }
    } else if (schedule == kmp_sched_auto) {
      schedule = kmp_sched_static_chunk;
      chunk = 1;
    } else {
      // ASSERT(LT_FUSSY,
      //        schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
      //        "unknown schedule %d & chunk %lld\n", (int)schedule,
      //        (long long)chunk);
    }

    // init schedules
    if (schedule == kmp_sched_static_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_balanced_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      // round up to make sure the chunk is enough to cover all iterations
      T span = (tripCount + tnum - 1) / tnum;
      // perform chunk adjustment
      chunk = (span + chunk - 1) & ~(chunk - 1);

      T oldUb = ub;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
      if (ub > oldUb)
        ub = oldUb;
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_nochunk) {
      ASSERT0(LT_FUSSY, chunk == 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_dynamic || schedule == kmp_sched_guided) {
      // save data
      DST->ScheduleType = schedule;
      if (chunk < 1)
        chunk = 1;
      DST->Chunk = chunk;
      DST->LoopUpperBound = ub;
      DST->NextLowerBound = lb;
      __kmpc_barrier(loc, threadId);
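      // One thread per team resets the shared chunk counter used by NextIter;
      // the surrounding barriers keep any thread from grabbing a chunk before
      // the reset is visible.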
      if (tid == 0) {
        Cnt = 0;
        fence::team(atomic::seq_cst);
      }
      __kmpc_barrier(loc, threadId);
    }
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch next

  static uint64_t NextIter() {
    __kmpc_impl_lanemask_t active = mapping::activemask();
    uint32_t leader = utils::ffs(active) - 1;
    uint32_t change = utils::popc(active);
    __kmpc_impl_lanemask_t lane_mask_lt = mapping::lanemaskLT();
    unsigned int rank = utils::popc(active & lane_mask_lt);
    uint64_t warp_res = 0;
    if (rank == 0) {
      warp_res = atomic::add(&Cnt, change, atomic::seq_cst);
    }
    warp_res = utils::shuffle(active, warp_res, leader);
    return warp_res + rank;
  }
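
  // NextIter performs a warp-aggregated atomic: instead of every active lane
  // incrementing Cnt, the lowest active lane adds the number of active lanes
  // once, and each lane derives its own ticket from the leader's old value
  // plus its rank. Illustrative example: with active lanes {0, 1, 3} and
  // Cnt == 10, the leader adds 3 (Cnt becomes 13) and the lanes obtain the
  // tickets 10, 11 and 12.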

  static int DynamicNextChunk(T &lb, T &ub, T chunkSize, T loopLowerBound,
                              T loopUpperBound) {
    T N = NextIter();
    lb = loopLowerBound + N * chunkSize;
    ub = lb + chunkSize - 1; // Clang uses i <= ub

    // 3 result cases:
    //  a. lb and ub < loopUpperBound --> NOT_FINISHED
    //  b. lb < loopUpperBound and ub >= loopUpperBound: last chunk -->
    //     LAST_CHUNK
    //  c. lb and ub >= loopUpperBound: empty chunk --> FINISHED
    // a.
    if (lb <= loopUpperBound && ub < loopUpperBound) {
      return NOT_FINISHED;
    }
    // b.
    if (lb <= loopUpperBound) {
      ub = loopUpperBound;
      return LAST_CHUNK;
    }
    // c. if we are here, we are in case 'c'
    lb = loopUpperBound + 2;
    ub = loopUpperBound + 1;
    return FINISHED;
  }
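
  // Worked example (illustrative): with loopLowerBound=0, loopUpperBound=9 and
  // chunkSize=4, ticket N=0 yields [0,3] (NOT_FINISHED), N=2 yields [8,11]
  // which is clamped to [8,9] (LAST_CHUNK), and N=3 lies entirely past the
  // loop and returns FINISHED.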

  static int dispatch_next(IdentTy *loc, int32_t gtid, int32_t *plast,
                           T *plower, T *pupper, ST *pstride,
                           DynamicScheduleTracker *DST) {
    // ID of a thread in its own warp

    // automatically selects thread or warp ID based on selected implementation
    ASSERT0(LT_FUSSY, gtid < omp_get_num_threads(),
            "current thread is not needed here; error");
    // retrieve schedule
    kmp_sched_t schedule = DST->ScheduleType;

    // xxx reduce to one
    if (schedule == kmp_sched_static_chunk ||
        schedule == kmp_sched_static_nochunk) {
      T myLb = DST->NextLowerBound;
      T ub = DST->LoopUpperBound;
      // finished?
      if (myLb > ub) {
        return DISPATCH_FINISHED;
      }
      // not finished, save current bounds
      ST chunk = DST->Chunk;
      *plower = myLb;
      T myUb = myLb + chunk - 1; // Clang uses i <= ub
      if (myUb > ub)
        myUb = ub;
      *pupper = myUb;
      *plast = (int32_t)(myUb == ub);

      // increment next lower bound by the stride
      ST stride = DST->Stride;
      DST->NextLowerBound = myLb + stride;
      return DISPATCH_NOTFINISHED;
    }
    ASSERT0(LT_FUSSY,
            schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
            "bad sched");
    T myLb, myUb;
    int finished = DynamicNextChunk(myLb, myUb, DST->Chunk,
                                    DST->NextLowerBound, DST->LoopUpperBound);

    if (finished == FINISHED)
      return DISPATCH_FINISHED;

    // not finished (either not finished or last chunk)
    *plast = (int32_t)(finished == LAST_CHUNK);
    *plower = myLb;
    *pupper = myUb;
    *pstride = 1;

    return DISPATCH_NOTFINISHED;
  }

  static void dispatch_fini() {
    // nothing
  }

  ////////////////////////////////////////////////////////////////////////////////
  // end of template class that encapsulates all the helper functions
  ////////////////////////////////////////////////////////////////////////////////
};

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (dyn loops)
////////////////////////////////////////////////////////////////////////////////

// TODO: This is a stopgap. We probably want to expand the dispatch API to take
// a DST pointer which can then be allocated properly without malloc.
static DynamicScheduleTracker *THREAD_LOCAL(ThreadDSTPtr);

// Create a new DST, link it to the current one, and make the new one current.
static DynamicScheduleTracker *pushDST() {
  DynamicScheduleTracker *NewDST = static_cast<DynamicScheduleTracker *>(
      memory::allocGlobal(sizeof(DynamicScheduleTracker), "new DST"));
  *NewDST = DynamicScheduleTracker({0});
  NewDST->NextDST = ThreadDSTPtr;
  ThreadDSTPtr = NewDST;
  return ThreadDSTPtr;
}

// Return the current DST.
static DynamicScheduleTracker *peekDST() { return ThreadDSTPtr; }

// Pop the current DST and restore the last one.
static void popDST() {
  DynamicScheduleTracker *OldDST = ThreadDSTPtr->NextDST;
  memory::freeGlobal(ThreadDSTPtr, "remove DST");
  ThreadDSTPtr = OldDST;
}
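
// The entry points below implement the usual KMPC dispatch protocol. A rough,
// illustrative sketch of the code a compiler emits for
// `#pragma omp for schedule(dynamic, C)` over the iterations [0, N) is (the
// exact code generation differs):
//
//   __kmpc_dispatch_init_4(loc, tid, kmp_sched_dynamic, /*lb=*/0,
//                          /*ub=*/N - 1, /*st=*/1, /*chunk=*/C);
//   int32_t last, lb, ub, st;
//   while (__kmpc_dispatch_next_4(loc, tid, &last, &lb, &ub, &st))
//     for (int32_t i = lb; i <= ub; ++i)
//       body(i);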

extern "C" {

// init
void __kmpc_dispatch_init_4(IdentTy *loc, int32_t tid, int32_t schedule,
                            int32_t lb, int32_t ub, int32_t st, int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_4u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint32_t lb, uint32_t ub, int32_t st,
                             int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8(IdentTy *loc, int32_t tid, int32_t schedule,
                            int64_t lb, int64_t ub, int64_t st, int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint64_t lb, uint64_t ub, int64_t st,
                             int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

// next
int __kmpc_dispatch_next_4(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int32_t *p_lb, int32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_4u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint32_t *p_lb, uint32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int64_t *p_lb, int64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint64_t *p_lb, uint64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

// fini
void __kmpc_dispatch_fini_4(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_fini();
  popDST();
}

void __kmpc_dispatch_fini_4u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_fini();
  popDST();
}

void __kmpc_dispatch_fini_8(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_fini();
  popDST();
}

void __kmpc_dispatch_fini_8u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_fini();
  popDST();
}

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (static loops)
////////////////////////////////////////////////////////////////////////////////

void __kmpc_for_static_init_4(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int32_t *plower, int32_t *pupper,
                              int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_4u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint32_t *plower, uint32_t *pupper,
                               int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int64_t *plower, int64_t *pupper,
                              int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint64_t *plower, uint64_t *pupper,
                               int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}
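
// A rough, illustrative sketch (not the exact Clang code generation) of how
// __kmpc_for_static_init_4 is driven for `#pragma omp for schedule(static, C)`
// over the iterations [0, N):
//
//   int32_t last = 0, lb = 0, ub = N - 1, stride = 1;
//   __kmpc_for_static_init_4(loc, tid, kmp_sched_static_chunk, &last, &lb,
//                            &ub, &stride, /*incr=*/1, /*chunk=*/C);
//   for (; lb < N; lb += stride, ub += stride)
//     for (int32_t i = lb; i <= ub && i < N; ++i)
//       body(i);
//   __kmpc_for_static_fini(loc, tid);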

void __kmpc_distribute_static_init_4(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int32_t *plower, int32_t *pupper,
                                     int32_t *pstride, int32_t incr,
                                     int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_4u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint32_t *plower, uint32_t *pupper,
                                      int32_t *pstride, int32_t incr,
                                      int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int64_t *plower, int64_t *pupper,
                                     int64_t *pstride, int64_t incr,
                                     int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint64_t *plower, uint64_t *pupper,
                                      int64_t *pstride, int64_t incr,
                                      int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_fini(IdentTy *loc, int32_t global_tid) {}

void __kmpc_distribute_static_fini(IdentTy *loc, int32_t global_tid) {}
}

namespace ompx {

/// Helper class to hide the generic loop nest and provide the template
/// argument throughout.
template <typename Ty> class StaticLoopChunker {

  /// Generic loop nest that handles block and/or thread distribution in the
  /// absence of user specified chunk sizes. This implicitly picks a block
  /// chunk size equal to the number of threads in the block and a thread
  /// chunk size equal to one. In contrast to the chunked version we can get
  /// away with a single loop in this case.
  static void NormalizedLoopNestNoChunk(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty NumBlocks, Ty BId, Ty NumThreads,
                                        Ty TId, Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * NumThreads;

    // Start index in the normalized space.
    Ty IV = BId * NumThreads + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    if (IV < NumIters) {
      do {

        // Execute the loop body.
        LoopBody(IV, Arg);

        // Every thread executed one block and thread chunk now.
        IV += KernelIteration;

        if (OneIterationPerThread)
          return;

      } while (IV < NumIters);
    }
  }
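
  // Worked example (illustrative): with NumBlocks=2, NumThreads=4 and
  // NumIters=20, KernelIteration is 8; the thread with BId=1 and TId=2 starts
  // at IV = 1*4 + 2 = 6 and executes iterations 6 and 14 before IV reaches 22
  // and the loop exits.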

  /// Generic loop nest that handles block and/or thread distribution in the
  /// presence of user specified chunk sizes (for at least one of them).
  static void NormalizedLoopNestChunked(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty BlockChunk, Ty NumBlocks, Ty BId,
                                        Ty ThreadChunk, Ty NumThreads, Ty TId,
                                        Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * BlockChunk;

    // Start index in the chunked space.
    Ty IV = BId * BlockChunk + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    do {

      Ty BlockChunkLeft =
          BlockChunk >= TId * ThreadChunk ? BlockChunk - TId * ThreadChunk : 0;
      Ty ThreadChunkLeft =
          ThreadChunk <= BlockChunkLeft ? ThreadChunk : BlockChunkLeft;

      while (ThreadChunkLeft--) {

        // Given the blocking it's hard to keep track of what to execute.
        if (IV >= NumIters)
          return;

        // Execute the loop body.
        LoopBody(IV, Arg);

        if (OneIterationPerThread)
          return;

        ++IV;
      }

      IV += KernelIteration;

    } while (IV < NumIters);
  }

public:
  /// Worksharing `for`-loop.
  static void For(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                  Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(ThreadChunk >= 0, "Bad thread chunk");

    // All threads need to participate but we don't know if we are in a
    // parallel at all or if the user might have used a `num_threads` clause
    // on the parallel and reduced the number compared to the block size.
    // Since nested parallels are possible too we need to get the thread id
    // from the `omp` getter and not the mapping directly.
    Ty TId = omp_get_thread_num();

    // There are no blocks involved here.
    Ty BlockChunk = 0;
    Ty NumBlocks = 1;
    Ty BId = 0;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeThreadsOversubscription()) {
      ASSERT(NumThreads >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);
  }

  /// Worksharing `distribute`-loop.
  static void Distribute(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                         Ty NumIters, Ty BlockChunk) {
    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block chunk");

    // There are no threads involved here.
    Ty ThreadChunk = 0;
    Ty NumThreads = 1;
    Ty TId = 0;
    ASSERT(TId == mapping::getThreadIdInBlock(), "Bad thread id");

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If we know we have more blocks than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription()) {
      ASSERT(NumBlocks >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (BlockChunk != NumThreads)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");
  }

  /// Worksharing `distribute parallel for`-loop.
  static void DistributeFor(IdentTy *Loc, void (*LoopBody)(Ty, void *),
                            void *Arg, Ty NumIters, Ty NumThreads,
                            Ty BlockChunk, Ty ThreadChunk) {
    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block chunk");
    ASSERT(ThreadChunk >= 0, "Bad thread chunk");

    // All threads need to participate but the user might have used a
    // `num_threads` clause on the parallel and reduced the number compared to
    // the block size.
    Ty TId = mapping::getThreadIdInBlock();

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads (across all blocks) than iterations we
    // can indicate that to avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription() &&
        config::getAssumeThreadsOversubscription()) {
      OneIterationPerThread = true;
      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
    }

    if (BlockChunk != NumThreads || ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
  }
};

} // namespace ompx

#define OMP_LOOP_ENTRY(BW, TY)                                                 \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_for_static_loop##BW(                                   \
          IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,       \
          TY num_threads, TY block_chunk, TY thread_chunk) {                   \
    ompx::StaticLoopChunker<TY>::DistributeFor(                                \
        loc, fn, arg, num_iters + 1, num_threads, block_chunk, thread_chunk);  \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *),  \
                                        void *arg, TY num_iters,               \
                                        TY block_chunk) {                      \
    ompx::StaticLoopChunker<TY>::Distribute(loc, fn, arg, num_iters + 1,       \
                                            block_chunk);                      \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW(      \
      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
      TY num_threads, TY thread_chunk) {                                       \
    ompx::StaticLoopChunker<TY>::For(loc, fn, arg, num_iters + 1, num_threads, \
                                     thread_chunk);                            \
  }

extern "C" {
OMP_LOOP_ENTRY(_4, int32_t)
OMP_LOOP_ENTRY(_4u, uint32_t)
OMP_LOOP_ENTRY(_8, int64_t)
OMP_LOOP_ENTRY(_8u, uint64_t)
}
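
// Illustrative only: for a worksharing loop such as
//   for (int32_t i = 0; i < N; ++i) body(i);
// in a target region, the compiler outlines the body into a callback and
// drives it through one of the entry points above, roughly:
//
//   static void OutlinedBody(int32_t iv, void *args) { /* body(iv) */ }
//   __kmpc_for_static_loop_4(loc, OutlinedBody, &args, /*num_iters=*/...,
//                            /*num_threads=*/..., /*thread_chunk=*/0);
//
// `OutlinedBody` and the argument values are hypothetical; the exact encoding
// of the iteration count (note the `num_iters + 1` in the macro above) and of
// the remaining arguments is determined by the compiler.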

#pragma omp end declare target
