1//===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "check-cuda.h"
10#include "flang/Common/template.h"
11#include "flang/Evaluate/fold.h"
12#include "flang/Evaluate/tools.h"
13#include "flang/Evaluate/traverse.h"
14#include "flang/Parser/parse-tree-visitor.h"
15#include "flang/Parser/parse-tree.h"
16#include "flang/Parser/tools.h"
17#include "flang/Semantics/expression.h"
18#include "flang/Semantics/symbol.h"
19#include "flang/Semantics/tools.h"
20#include "llvm/ADT/StringSet.h"
21
// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
// and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the
// CUFKernelDoConstruct doesn't already have an embedded DoConstruct. No
// diagnostics are emitted here; errors about improper or missing DoConstructs
// are reported later by CUDAChecker::Enter(const parser::CUFKernelDoConstruct &).
27
28namespace Fortran::parser {
29struct Mutator {
30 template <typename A> bool Pre(A &) { return true; }
31 template <typename A> void Post(A &) {}
32 bool Pre(Block &);
33};
34
35bool Mutator::Pre(Block &block) {
36 for (auto iter{block.begin()}; iter != block.end(); ++iter) {
37 if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) {
38 auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)};
39 if (!nested) {
40 if (auto next{iter}; ++next != block.end()) {
41 if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) {
42 nested = std::move(*doConstruct);
43 block.erase(next);
44 }
45 }
46 }
47 } else {
48 Walk(*iter, *this);
49 }
50 }
51 return false;
52}
53} // namespace Fortran::parser
54
55namespace Fortran::semantics {
56
57bool CanonicalizeCUDA(parser::Program &program) {
58 parser::Mutator mutator;
59 parser::Walk(program, mutator);
60 return true;
61}
62
63using MaybeMsg = std::optional<parser::MessageFormattedText>;
64
65static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
66 "match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
67 "match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
68 "match_any_syncjd"};
69
70// Traverses an evaluate::Expr<> in search of unsupported operations
71// on the device.
72
73struct DeviceExprChecker
74 : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
75 using Result = MaybeMsg;
76 using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
77 explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
78 using Base::operator();
79 Result operator()(const evaluate::ProcedureDesignator &x) const {
80 if (const Symbol * sym{x.GetInterfaceSymbol()}) {
81 const auto *subp{
82 sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()};
83 if (subp) {
84 if (auto attrs{subp->cudaSubprogramAttrs()}) {
85 if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
86 *attrs == common::CUDASubprogramAttrs::Device) {
87 if (warpFunctions_.contains(sym->name().ToString()) &&
88 !context_.languageFeatures().IsEnabled(
89 Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
90 return parser::MessageFormattedText(
91 "warp match function disabled"_err_en_US);
92 }
93 return {};
94 }
95 }
96 }
97
98 const Symbol &ultimate{sym->GetUltimate()};
99 const Scope &scope{ultimate.owner()};
100 const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
101 // Allow ieee_arithmetic module functions to be called on the device.
102 // TODO: Check for unsupported ieee_arithmetic on the device.
103 if (mod && mod->name() == "ieee_arithmetic") {
104 return {};
105 }
106 } else if (x.GetSpecificIntrinsic()) {
107 // TODO(CUDA): Check for unsupported intrinsics here
108 return {};
109 }
110
111 return parser::MessageFormattedText(
112 "'%s' may not be called in device code"_err_en_US, x.GetName());
113 }
114
115 SemanticsContext &context_;
116};
117
118struct FindHostArray
119 : public evaluate::AnyTraverse<FindHostArray, const Symbol *> {
120 using Result = const Symbol *;
121 using Base = evaluate::AnyTraverse<FindHostArray, Result>;
122 FindHostArray() : Base(*this) {}
123 using Base::operator();
124 Result operator()(const evaluate::Component &x) const {
125 const Symbol &symbol{x.GetLastSymbol()};
126 if (IsAllocatableOrPointer(symbol)) {
127 if (Result hostArray{(*this)(symbol)}) {
128 return hostArray;
129 }
130 }
131 return (*this)(x.base());
132 }
133 Result operator()(const Symbol &symbol) const {
134 if (const auto *details{
135 symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
136 if (details->IsArray() &&
137 !symbol.attrs().test(Fortran::semantics::Attr::PARAMETER) &&
138 (!details->cudaDataAttr() ||
139 (details->cudaDataAttr() &&
140 *details->cudaDataAttr() != common::CUDADataAttr::Device &&
141 *details->cudaDataAttr() != common::CUDADataAttr::Constant &&
142 *details->cudaDataAttr() != common::CUDADataAttr::Managed &&
143 *details->cudaDataAttr() != common::CUDADataAttr::Shared &&
144 *details->cudaDataAttr() != common::CUDADataAttr::Unified))) {
145 return &symbol;
146 }
147 }
148 return nullptr;
149 }
150};
151
152template <typename A>
153static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
154 if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
155 return DeviceExprChecker{context}(expr->typedExpr);
156 }
157 return {};
158}
159
160template <typename A>
161static void CheckUnwrappedExpr(
162 SemanticsContext &context, SourceName at, const A &x) {
163 if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
164 if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
165 context.Say(at, std::move(*msg));
166 }
167 }
168}
169
170template <bool CUF_KERNEL> struct ActionStmtChecker {
171 template <typename A>
172 static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
173 if constexpr (ConstraintTrait<A>) {
174 return WhyNotOk(context, x.thing);
175 } else if constexpr (WrapperTrait<A>) {
176 return WhyNotOk(context, x.v);
177 } else if constexpr (UnionTrait<A>) {
178 return WhyNotOk(context, x.u);
179 } else if constexpr (TupleTrait<A>) {
180 return WhyNotOk(context, x.t);
181 } else {
182 return parser::MessageFormattedText{
183 "Statement may not appear in device code"_err_en_US};
184 }
185 }
186 template <typename A>
187 static MaybeMsg WhyNotOk(
188 SemanticsContext &context, const common::Indirection<A> &x) {
189 return WhyNotOk(context, x.value());
190 }
191 template <typename... As>
192 static MaybeMsg WhyNotOk(
193 SemanticsContext &context, const std::variant<As...> &x) {
194 return common::visit(
195 [&context](const auto &x) { return WhyNotOk(context, x); }, x);
196 }
197 template <std::size_t J = 0, typename... As>
198 static MaybeMsg WhyNotOk(
199 SemanticsContext &context, const std::tuple<As...> &x) {
200 if constexpr (J == sizeof...(As)) {
201 return {};
202 } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
203 return msg;
204 } else {
205 return WhyNotOk<(J + 1)>(context, x);
206 }
207 }
208 template <typename A>
209 static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
210 for (const auto &y : x) {
211 if (MaybeMsg result{WhyNotOk(context, y)}) {
212 return result;
213 }
214 }
215 return {};
216 }
217 template <typename A>
218 static MaybeMsg WhyNotOk(
219 SemanticsContext &context, const std::optional<A> &x) {
220 if (x) {
221 return WhyNotOk(context, *x);
222 } else {
223 return {};
224 }
225 }
226 template <typename A>
227 static MaybeMsg WhyNotOk(
228 SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
229 return WhyNotOk(context, x.statement);
230 }
231 template <typename A>
232 static MaybeMsg WhyNotOk(
233 SemanticsContext &context, const parser::Statement<A> &x) {
234 return WhyNotOk(context, x.statement);
235 }
236 static MaybeMsg WhyNotOk(
237 SemanticsContext &context, const parser::AllocateStmt &) {
238 return {}; // AllocateObjects are checked elsewhere
239 }
240 static MaybeMsg WhyNotOk(
241 SemanticsContext &context, const parser::AllocateCoarraySpec &) {
242 return parser::MessageFormattedText(
243 "A coarray may not be allocated on the device"_err_en_US);
244 }
245 static MaybeMsg WhyNotOk(
246 SemanticsContext &context, const parser::DeallocateStmt &) {
247 return {}; // AllocateObjects are checked elsewhere
248 }
249 static MaybeMsg WhyNotOk(
250 SemanticsContext &context, const parser::AssignmentStmt &x) {
251 return DeviceExprChecker{context}(x.typedAssignment);
252 }
253 static MaybeMsg WhyNotOk(
254 SemanticsContext &context, const parser::CallStmt &x) {
255 return DeviceExprChecker{context}(x.typedCall);
256 }
257 static MaybeMsg WhyNotOk(
258 SemanticsContext &context, const parser::ContinueStmt &) {
259 return {};
260 }
261 static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
262 if (auto result{CheckUnwrappedExpr(
263 context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
264 return result;
265 }
266 return WhyNotOk(context,
267 std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
268 .statement);
269 }
270 static MaybeMsg WhyNotOk(
271 SemanticsContext &context, const parser::NullifyStmt &x) {
272 for (const auto &y : x.v) {
273 if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
274 return result;
275 }
276 }
277 return {};
278 }
279 static MaybeMsg WhyNotOk(
280 SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
281 return DeviceExprChecker{context}(x.typedAssignment);
282 }
283};
284
285template <bool IsCUFKernelDo> class DeviceContextChecker {
286public:
287 explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {}
288 void CheckSubprogram(const parser::Name &name, const parser::Block &body) {
289 if (name.symbol) {
290 const auto *subp{
291 name.symbol->GetUltimate().detailsIf<SubprogramDetails>()};
292 if (subp && subp->moduleInterface()) {
293 subp = subp->moduleInterface()
294 ->GetUltimate()
295 .detailsIf<SubprogramDetails>();
296 }
297 if (subp &&
298 subp->cudaSubprogramAttrs().value_or(
299 common::CUDASubprogramAttrs::Host) !=
300 common::CUDASubprogramAttrs::Host) {
301 isHostDevice = subp->cudaSubprogramAttrs() &&
302 subp->cudaSubprogramAttrs() ==
303 common::CUDASubprogramAttrs::HostDevice;
304 Check(body);
305 }
306 }
307 }
308 void Check(const parser::Block &block) {
309 for (const auto &epc : block) {
310 Check(epc);
311 }
312 }
313
314private:
315 void Check(const parser::ExecutionPartConstruct &epc) {
316 common::visit(
317 common::visitors{
318 [&](const parser::ExecutableConstruct &x) { Check(x); },
319 [&](const parser::Statement<common::Indirection<parser::EntryStmt>>
320 &x) {
321 context_.Say(x.source,
322 "Device code may not contain an ENTRY statement"_err_en_US);
323 },
324 [](const parser::Statement<common::Indirection<parser::FormatStmt>>
325 &) {},
326 [](const parser::Statement<common::Indirection<parser::DataStmt>>
327 &) {},
328 [](const parser::Statement<
329 common::Indirection<parser::NamelistStmt>> &) {},
330 [](const parser::ErrorRecovery &) {},
331 },
332 epc.u);
333 }
334 void Check(const parser::ExecutableConstruct &ec) {
335 common::visit(
336 common::visitors{
337 [&](const parser::Statement<parser::ActionStmt> &stmt) {
338 Check(stmt.statement, stmt.source);
339 },
340 [&](const common::Indirection<parser::DoConstruct> &x) {
341 if (const std::optional<parser::LoopControl> &control{
342 x.value().GetLoopControl()}) {
343 common::visit([&](const auto &y) { Check(y); }, control->u);
344 }
345 Check(std::get<parser::Block>(x.value().t));
346 },
347 [&](const common::Indirection<parser::BlockConstruct> &x) {
348 Check(std::get<parser::Block>(x.value().t));
349 },
350 [&](const common::Indirection<parser::IfConstruct> &x) {
351 Check(x.value());
352 },
353 [&](const common::Indirection<parser::CaseConstruct> &x) {
354 const auto &caseList{
355 std::get<std::list<parser::CaseConstruct::Case>>(
356 x.value().t)};
357 for (const parser::CaseConstruct::Case &c : caseList) {
358 Check(std::get<parser::Block>(c.t));
359 }
360 },
361 [&](const common::Indirection<parser::CompilerDirective> &x) {
362 // TODO(CUDA): Check for unsupported compiler directive here.
363 },
364 [&](const auto &x) {
365 if (auto source{parser::GetSource(x)}) {
366 context_.Say(*source,
367 "Statement may not appear in device code"_err_en_US);
368 }
369 },
370 },
371 ec.u);
372 }
373 template <typename SEEK, typename A>
374 static const SEEK *GetIOControl(const A &stmt) {
375 for (const auto &spec : stmt.controls) {
376 if (const auto *result{std::get_if<SEEK>(&spec.u)}) {
377 return result;
378 }
379 }
380 return nullptr;
381 }
382 template <typename A> static bool IsInternalIO(const A &stmt) {
383 if (stmt.iounit.has_value()) {
384 return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u);
385 }
386 if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) {
387 return std::holds_alternative<Fortran::parser::Variable>(unit->u);
388 }
389 return false;
390 }
391 void WarnOnIoStmt(const parser::CharBlock &source) {
392 context_.Warn(common::UsageWarning::CUDAUsage, source,
393 "I/O statement might not be supported on device"_warn_en_US);
394 }
395 template <typename A>
396 void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) {
397 if (!IsInternalIO(stmt)) {
398 WarnOnIoStmt(source);
399 }
400 }
401 template <typename A>
402 void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) {
403 if (isHostDevice)
404 return;
405 if (const Symbol * hostArray{FindHostArray{}(expr)}) {
406 context_.Say(source,
407 "Host array '%s' cannot be present in device context"_err_en_US,
408 hostArray->name());
409 }
410 }
411 void ErrorInCUFKernel(parser::CharBlock source) {
412 if (IsCUFKernelDo) {
413 context_.Say(
414 source, "Statement may not appear in cuf kernel code"_err_en_US);
415 }
416 }
417 void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
418 common::visit(
419 common::visitors{
420 [&](const common::Indirection<parser::CycleStmt> &) {
421 ErrorInCUFKernel(source);
422 },
423 [&](const common::Indirection<parser::ExitStmt> &) {
424 ErrorInCUFKernel(source);
425 },
426 [&](const common::Indirection<parser::GotoStmt> &) {
427 ErrorInCUFKernel(source);
428 },
429 [&](const common::Indirection<parser::StopStmt> &) { return; },
430 [&](const common::Indirection<parser::PrintStmt> &) {},
431 [&](const common::Indirection<parser::WriteStmt> &x) {
432 if (x.value().format) { // Formatted write to '*' or '6'
433 if (std::holds_alternative<Fortran::parser::Star>(
434 x.value().format->u)) {
435 if (x.value().iounit) {
436 if (std::holds_alternative<Fortran::parser::Star>(
437 x.value().iounit->u)) {
438 return;
439 }
440 }
441 }
442 }
443 WarnIfNotInternal(x.value(), source);
444 },
445 [&](const common::Indirection<parser::CloseStmt> &x) {
446 WarnOnIoStmt(source);
447 },
448 [&](const common::Indirection<parser::EndfileStmt> &x) {
449 WarnOnIoStmt(source);
450 },
451 [&](const common::Indirection<parser::OpenStmt> &x) {
452 WarnOnIoStmt(source);
453 },
454 [&](const common::Indirection<parser::ReadStmt> &x) {
455 WarnIfNotInternal(x.value(), source);
456 },
457 [&](const common::Indirection<parser::InquireStmt> &x) {
458 WarnOnIoStmt(source);
459 },
460 [&](const common::Indirection<parser::RewindStmt> &x) {
461 WarnOnIoStmt(source);
462 },
463 [&](const common::Indirection<parser::BackspaceStmt> &x) {
464 WarnOnIoStmt(source);
465 },
466 [&](const common::Indirection<parser::IfStmt> &x) {
467 Check(x.value());
468 },
469 [&](const common::Indirection<parser::AssignmentStmt> &x) {
470 if (const evaluate::Assignment *
471 assign{semantics::GetAssignment(x.value())}) {
472 ErrorIfHostSymbol(assign->lhs, source);
473 ErrorIfHostSymbol(assign->rhs, source);
474 }
475 if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
476 context_, x)}) {
477 context_.Say(source, std::move(*msg));
478 }
479 },
480 [&](const auto &x) {
481 if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
482 context_, x)}) {
483 context_.Say(source, std::move(*msg));
484 }
485 },
486 },
487 stmt.u);
488 }
489 void Check(const parser::IfConstruct &ic) {
490 const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)};
491 CheckUnwrappedExpr(context_, ifS.source,
492 std::get<parser::ScalarLogicalExpr>(ifS.statement.t));
493 Check(std::get<parser::Block>(ic.t));
494 for (const auto &eib :
495 std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) {
496 const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)};
497 CheckUnwrappedExpr(context_, eIfS.source,
498 std::get<parser::ScalarLogicalExpr>(eIfS.statement.t));
499 Check(std::get<parser::Block>(eib.t));
500 }
501 if (const auto &eb{
502 std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) {
503 Check(std::get<parser::Block>(eb->t));
504 }
505 }
506 void Check(const parser::IfStmt &is) {
507 const auto &uS{
508 std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
509 CheckUnwrappedExpr(
510 context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
511 Check(uS.statement, uS.source);
512 }
513 void Check(const parser::LoopControl::Bounds &bounds) {
514 Check(bounds.lower);
515 Check(bounds.upper);
516 if (bounds.step) {
517 Check(*bounds.step);
518 }
519 }
520 void Check(const parser::LoopControl::Concurrent &x) {
521 const auto &header{std::get<parser::ConcurrentHeader>(x.t)};
522 for (const auto &cc :
523 std::get<std::list<parser::ConcurrentControl>>(header.t)) {
524 Check(std::get<1>(cc.t));
525 Check(std::get<2>(cc.t));
526 if (const auto &step{
527 std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) {
528 Check(*step);
529 }
530 }
531 if (const auto &mask{
532 std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
533 Check(*mask);
534 }
535 }
536 void Check(const parser::ScalarLogicalExpr &x) {
537 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
538 }
539 void Check(const parser::ScalarIntExpr &x) {
540 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
541 }
542 void Check(const parser::ScalarExpr &x) {
543 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
544 }
545 void Check(const parser::Expr &expr) {
546 if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
547 context_.Say(expr.source, std::move(*msg));
548 }
549 }
550
551 SemanticsContext &context_;
552 bool isHostDevice{false};
553};
554
555void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) {
556 DeviceContextChecker<false>{context_}.CheckSubprogram(
557 std::get<parser::Name>(
558 std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t),
559 std::get<parser::ExecutionPart>(x.t).v);
560}
561
562void CUDAChecker::Enter(const parser::FunctionSubprogram &x) {
563 DeviceContextChecker<false>{context_}.CheckSubprogram(
564 std::get<parser::Name>(
565 std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t),
566 std::get<parser::ExecutionPart>(x.t).v);
567}
568
569void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) {
570 DeviceContextChecker<false>{context_}.CheckSubprogram(
571 std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v,
572 std::get<parser::ExecutionPart>(x.t).v);
573}
574
575// !$CUF KERNEL DO semantic checks
576
577static int DoConstructTightNesting(
578 const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) {
579 if (!doConstruct ||
580 (!doConstruct->IsDoNormal() && !doConstruct->IsDoConcurrent())) {
581 return 0;
582 }
583 innerBlock = &std::get<parser::Block>(doConstruct->t);
584 if (doConstruct->IsDoConcurrent()) {
585 const auto &loopControl = doConstruct->GetLoopControl();
586 if (loopControl) {
587 if (const auto *concurrentControl{
588 std::get_if<parser::LoopControl::Concurrent>(&loopControl->u)}) {
589 const auto &concurrentHeader =
590 std::get<Fortran::parser::ConcurrentHeader>(concurrentControl->t);
591 const auto &controls =
592 std::get<std::list<Fortran::parser::ConcurrentControl>>(
593 concurrentHeader.t);
594 return controls.size();
595 }
596 }
597 return 0;
598 }
599 if (innerBlock->size() == 1) {
600 if (const auto *execConstruct{
601 std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
602 if (const auto *next{
603 std::get_if<common::Indirection<parser::DoConstruct>>(
604 &execConstruct->u)}) {
605 return 1 + DoConstructTightNesting(&next->value(), innerBlock);
606 }
607 }
608 }
609 return 1;
610}
611
612static void CheckReduce(
613 SemanticsContext &context, const parser::CUFReduction &reduce) {
614 auto op{std::get<parser::CUFReduction::Operator>(reduce.t).v};
615 for (const auto &var :
616 std::get<std::list<parser::Scalar<parser::Variable>>>(reduce.t)) {
617 if (const auto &typedExprPtr{var.thing.typedExpr};
618 typedExprPtr && typedExprPtr->v) {
619 const auto &expr{*typedExprPtr->v};
620 if (auto type{expr.GetType()}) {
621 auto cat{type->category()};
622 bool isOk{false};
623 switch (op) {
624 case parser::ReductionOperator::Operator::Plus:
625 case parser::ReductionOperator::Operator::Multiply:
626 case parser::ReductionOperator::Operator::Max:
627 case parser::ReductionOperator::Operator::Min:
628 isOk = cat == TypeCategory::Integer || cat == TypeCategory::Real ||
629 cat == TypeCategory::Complex;
630 break;
631 case parser::ReductionOperator::Operator::Iand:
632 case parser::ReductionOperator::Operator::Ior:
633 case parser::ReductionOperator::Operator::Ieor:
634 isOk = cat == TypeCategory::Integer;
635 break;
636 case parser::ReductionOperator::Operator::And:
637 case parser::ReductionOperator::Operator::Or:
638 case parser::ReductionOperator::Operator::Eqv:
639 case parser::ReductionOperator::Operator::Neqv:
640 isOk = cat == TypeCategory::Logical;
641 break;
642 }
643 if (!isOk) {
644 context.Say(var.thing.GetSource(),
645 "!$CUF KERNEL DO REDUCE operation is not acceptable for a variable with type %s"_err_en_US,
646 type->AsFortran());
647 }
648 }
649 }
650 }
651}
652
653void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
654 auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source};
655 const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)};
656 std::int64_t depth{1};
657 if (auto expr{AnalyzeExpr(context_,
658 std::get<std::optional<parser::ScalarIntConstantExpr>>(
659 directive.t))}) {
660 depth = evaluate::ToInt64(expr).value_or(0);
661 if (depth <= 0) {
662 context_.Say(source,
663 "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US,
664 std::intmax_t{depth});
665 depth = 1;
666 }
667 }
668 const parser::DoConstruct *doConstruct{common::GetPtrFromOptional(
669 std::get<std::optional<parser::DoConstruct>>(x.t))};
670 const parser::Block *innerBlock{nullptr};
671 if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {
672 if (doConstruct && doConstruct->IsDoConcurrent())
673 context_.Say(source,
674 "!$CUF KERNEL DO (%jd) must be followed by a DO CONCURRENT construct with at least %jd indices"_err_en_US,
675 std::intmax_t{depth}, std::intmax_t{depth});
676 else
677 context_.Say(source,
678 "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
679 std::intmax_t{depth});
680 }
681 if (innerBlock) {
682 DeviceContextChecker<true>{context_}.Check(*innerBlock);
683 }
684 for (const auto &reduce :
685 std::get<std::list<parser::CUFReduction>>(directive.t)) {
686 CheckReduce(context_, reduce);
687 }
688 inCUFKernelDoConstruct_ = true;
689}
690
691void CUDAChecker::Leave(const parser::CUFKernelDoConstruct &) {
692 inCUFKernelDoConstruct_ = false;
693}
694
695void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
696 auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};
697 const auto &scope{context_.FindScope(lhsLoc)};
698 const Scope &progUnit{GetProgramUnitContaining(scope)};
699 if (IsCUDADeviceContext(&progUnit) || inCUFKernelDoConstruct_) {
700 return; // Data transfer with assignment is only perform on host.
701 }
702
703 const evaluate::Assignment *assign{semantics::GetAssignment(x)};
704 if (!assign) {
705 return;
706 }
707
708 int nbLhs{evaluate::GetNbOfCUDADeviceSymbols(assign->lhs)};
709 int nbRhs{evaluate::GetNbOfCUDADeviceSymbols(assign->rhs)};
710
711 // device to host transfer with more than one device object on the rhs is not
712 // legal.
713 if (nbLhs == 0 && nbRhs > 1) {
714 context_.Say(lhsLoc,
715 "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
716 }
717}
718
719} // namespace Fortran::semantics
720

source code of flang/lib/Semantics/check-cuda.cpp