1//===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "check-cuda.h"
10#include "flang/Common/template.h"
11#include "flang/Evaluate/fold.h"
12#include "flang/Evaluate/tools.h"
13#include "flang/Evaluate/traverse.h"
14#include "flang/Parser/parse-tree-visitor.h"
15#include "flang/Parser/parse-tree.h"
16#include "flang/Parser/tools.h"
17#include "flang/Semantics/expression.h"
18#include "flang/Semantics/symbol.h"
19#include "flang/Semantics/tools.h"
20
21// Once labeled DO constructs have been canonicalized and their parse subtrees
22// transformed into parser::DoConstructs, scan the parser::Blocks of the program
23// and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the
24// CUFKernelDoConstruct doesn't already have an embedded DoConstruct. Also
25// emit errors about improper or missing DoConstructs.
26
27namespace Fortran::parser {
28struct Mutator {
29 template <typename A> bool Pre(A &) { return true; }
30 template <typename A> void Post(A &) {}
31 bool Pre(Block &);
32};
33
34bool Mutator::Pre(Block &block) {
35 for (auto iter{block.begin()}; iter != block.end(); ++iter) {
36 if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) {
37 auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)};
38 if (!nested) {
39 if (auto next{iter}; ++next != block.end()) {
40 if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) {
41 nested = std::move(*doConstruct);
42 block.erase(next);
43 }
44 }
45 }
46 } else {
47 Walk(*iter, *this);
48 }
49 }
50 return false;
51}
52} // namespace Fortran::parser
53
54namespace Fortran::semantics {
55
56bool CanonicalizeCUDA(parser::Program &program) {
57 parser::Mutator mutator;
58 parser::Walk(program, mutator);
59 return true;
60}
61
62using MaybeMsg = std::optional<parser::MessageFormattedText>;
63
64// Traverses an evaluate::Expr<> in search of unsupported operations
65// on the device.
66
67struct DeviceExprChecker
68 : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
69 using Result = MaybeMsg;
70 using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
71 DeviceExprChecker() : Base(*this) {}
72 using Base::operator();
73 Result operator()(const evaluate::ProcedureDesignator &x) const {
74 if (const Symbol * sym{x.GetInterfaceSymbol()}) {
75 const auto *subp{
76 sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()};
77 if (subp) {
78 if (auto attrs{subp->cudaSubprogramAttrs()}) {
79 if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
80 *attrs == common::CUDASubprogramAttrs::Device) {
81 return {};
82 }
83 }
84 }
85 } else if (x.GetSpecificIntrinsic()) {
86 // TODO(CUDA): Check for unsupported intrinsics here
87 return {};
88 }
89 return parser::MessageFormattedText(
90 "'%s' may not be called in device code"_err_en_US, x.GetName());
91 }
92};
93
94template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
95 if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
96 return DeviceExprChecker{}(expr->typedExpr);
97 }
98 return {};
99}
100
101template <typename A>
102static void CheckUnwrappedExpr(
103 SemanticsContext &context, SourceName at, const A &x) {
104 if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
105 if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
106 context.Say(at, std::move(*msg));
107 }
108 }
109}
110
111template <bool CUF_KERNEL> struct ActionStmtChecker {
112 template <typename A> static MaybeMsg WhyNotOk(const A &x) {
113 if constexpr (ConstraintTrait<A>) {
114 return WhyNotOk(x.thing);
115 } else if constexpr (WrapperTrait<A>) {
116 return WhyNotOk(x.v);
117 } else if constexpr (UnionTrait<A>) {
118 return WhyNotOk(x.u);
119 } else if constexpr (TupleTrait<A>) {
120 return WhyNotOk(x.t);
121 } else {
122 return parser::MessageFormattedText{
123 "Statement may not appear in device code"_err_en_US};
124 }
125 }
126 template <typename A>
127 static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
128 return WhyNotOk(x.value());
129 }
130 template <typename... As>
131 static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
132 return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
133 }
134 template <std::size_t J = 0, typename... As>
135 static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
136 if constexpr (J == sizeof...(As)) {
137 return {};
138 } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
139 return msg;
140 } else {
141 return WhyNotOk<(J + 1)>(x);
142 }
143 }
144 template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
145 for (const auto &y : x) {
146 if (MaybeMsg result{WhyNotOk(y)}) {
147 return result;
148 }
149 }
150 return {};
151 }
152 template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
153 if (x) {
154 return WhyNotOk(*x);
155 } else {
156 return {};
157 }
158 }
159 template <typename A>
160 static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
161 return WhyNotOk(x.statement);
162 }
163 template <typename A>
164 static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
165 return WhyNotOk(x.statement);
166 }
167 static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
168 return {}; // AllocateObjects are checked elsewhere
169 }
170 static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
171 return parser::MessageFormattedText(
172 "A coarray may not be allocated on the device"_err_en_US);
173 }
174 static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
175 return {}; // AllocateObjects are checked elsewhere
176 }
177 static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
178 return DeviceExprChecker{}(x.typedAssignment);
179 }
180 static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
181 return DeviceExprChecker{}(x.typedCall);
182 }
183 static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
184 static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
185 if (auto result{
186 CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
187 return result;
188 }
189 return WhyNotOk(
190 std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
191 .statement);
192 }
193 static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
194 for (const auto &y : x.v) {
195 if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
196 return result;
197 }
198 }
199 return {};
200 }
201 static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
202 return DeviceExprChecker{}(x.typedAssignment);
203 }
204};
205
206template <bool IsCUFKernelDo> class DeviceContextChecker {
207public:
208 explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {}
209 void CheckSubprogram(const parser::Name &name, const parser::Block &body) {
210 if (name.symbol) {
211 const auto *subp{
212 name.symbol->GetUltimate().detailsIf<SubprogramDetails>()};
213 if (subp && subp->moduleInterface()) {
214 subp = subp->moduleInterface()
215 ->GetUltimate()
216 .detailsIf<SubprogramDetails>();
217 }
218 if (subp &&
219 subp->cudaSubprogramAttrs().value_or(
220 common::CUDASubprogramAttrs::Host) !=
221 common::CUDASubprogramAttrs::Host) {
222 Check(body);
223 }
224 }
225 }
226 void Check(const parser::Block &block) {
227 for (const auto &epc : block) {
228 Check(epc);
229 }
230 }
231
232private:
233 void Check(const parser::ExecutionPartConstruct &epc) {
234 common::visit(
235 common::visitors{
236 [&](const parser::ExecutableConstruct &x) { Check(x); },
237 [&](const parser::Statement<common::Indirection<parser::EntryStmt>>
238 &x) {
239 context_.Say(x.source,
240 "Device code may not contain an ENTRY statement"_err_en_US);
241 },
242 [](const parser::Statement<common::Indirection<parser::FormatStmt>>
243 &) {},
244 [](const parser::Statement<common::Indirection<parser::DataStmt>>
245 &) {},
246 [](const parser::Statement<
247 common::Indirection<parser::NamelistStmt>> &) {},
248 [](const parser::ErrorRecovery &) {},
249 },
250 epc.u);
251 }
252 void Check(const parser::ExecutableConstruct &ec) {
253 common::visit(
254 common::visitors{
255 [&](const parser::Statement<parser::ActionStmt> &stmt) {
256 Check(stmt.statement, stmt.source);
257 },
258 [&](const common::Indirection<parser::DoConstruct> &x) {
259 if (const std::optional<parser::LoopControl> &control{
260 x.value().GetLoopControl()}) {
261 common::visit([&](const auto &y) { Check(y); }, control->u);
262 }
263 Check(std::get<parser::Block>(x.value().t));
264 },
265 [&](const common::Indirection<parser::BlockConstruct> &x) {
266 Check(std::get<parser::Block>(x.value().t));
267 },
268 [&](const common::Indirection<parser::IfConstruct> &x) {
269 Check(x.value());
270 },
271 [&](const auto &x) {
272 if (auto source{parser::GetSource(x)}) {
273 context_.Say(*source,
274 "Statement may not appear in device code"_err_en_US);
275 }
276 },
277 },
278 ec.u);
279 }
280 template <typename SEEK, typename A>
281 static const SEEK *GetIOControl(const A &stmt) {
282 for (const auto &spec : stmt.controls) {
283 if (const auto *result{std::get_if<SEEK>(&spec.u)}) {
284 return result;
285 }
286 }
287 return nullptr;
288 }
289 template <typename A> static bool IsInternalIO(const A &stmt) {
290 if (stmt.iounit.has_value()) {
291 return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u);
292 }
293 if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) {
294 return std::holds_alternative<Fortran::parser::Variable>(unit->u);
295 }
296 return false;
297 }
298 void WarnOnIoStmt(const parser::CharBlock &source) {
299 context_.Say(
300 source, "I/O statement might not be supported on device"_warn_en_US);
301 }
302 template <typename A>
303 void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) {
304 if (!IsInternalIO(stmt)) {
305 WarnOnIoStmt(source);
306 }
307 }
308 void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) {
309 common::visit(
310 common::visitors{
311 [&](const common::Indirection<parser::PrintStmt> &) {},
312 [&](const common::Indirection<parser::WriteStmt> &x) {
313 if (x.value().format) { // Formatted write to '*' or '6'
314 if (std::holds_alternative<Fortran::parser::Star>(
315 x.value().format->u)) {
316 if (x.value().iounit) {
317 if (std::holds_alternative<Fortran::parser::Star>(
318 x.value().iounit->u)) {
319 return;
320 }
321 }
322 }
323 }
324 WarnIfNotInternal(x.value(), source);
325 },
326 [&](const common::Indirection<parser::CloseStmt> &x) {
327 WarnOnIoStmt(source);
328 },
329 [&](const common::Indirection<parser::EndfileStmt> &x) {
330 WarnOnIoStmt(source);
331 },
332 [&](const common::Indirection<parser::OpenStmt> &x) {
333 WarnOnIoStmt(source);
334 },
335 [&](const common::Indirection<parser::ReadStmt> &x) {
336 WarnIfNotInternal(x.value(), source);
337 },
338 [&](const common::Indirection<parser::InquireStmt> &x) {
339 WarnOnIoStmt(source);
340 },
341 [&](const common::Indirection<parser::RewindStmt> &x) {
342 WarnOnIoStmt(source);
343 },
344 [&](const common::Indirection<parser::BackspaceStmt> &x) {
345 WarnOnIoStmt(source);
346 },
347 [&](const common::Indirection<parser::IfStmt> &x) {
348 Check(x.value());
349 },
350 [&](const auto &x) {
351 if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
352 context_.Say(source, std::move(*msg));
353 }
354 },
355 },
356 stmt.u);
357 }
358 void Check(const parser::IfConstruct &ic) {
359 const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)};
360 CheckUnwrappedExpr(context_, ifS.source,
361 std::get<parser::ScalarLogicalExpr>(ifS.statement.t));
362 Check(std::get<parser::Block>(ic.t));
363 for (const auto &eib :
364 std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) {
365 const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)};
366 CheckUnwrappedExpr(context_, eIfS.source,
367 std::get<parser::ScalarLogicalExpr>(eIfS.statement.t));
368 Check(std::get<parser::Block>(eib.t));
369 }
370 if (const auto &eb{
371 std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) {
372 Check(std::get<parser::Block>(eb->t));
373 }
374 }
375 void Check(const parser::IfStmt &is) {
376 const auto &uS{
377 std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)};
378 CheckUnwrappedExpr(
379 context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t));
380 Check(uS.statement, uS.source);
381 }
382 void Check(const parser::LoopControl::Bounds &bounds) {
383 Check(bounds.lower);
384 Check(bounds.upper);
385 if (bounds.step) {
386 Check(*bounds.step);
387 }
388 }
389 void Check(const parser::LoopControl::Concurrent &x) {
390 const auto &header{std::get<parser::ConcurrentHeader>(x.t)};
391 for (const auto &cc :
392 std::get<std::list<parser::ConcurrentControl>>(header.t)) {
393 Check(std::get<1>(cc.t));
394 Check(std::get<2>(cc.t));
395 if (const auto &step{
396 std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) {
397 Check(*step);
398 }
399 }
400 if (const auto &mask{
401 std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) {
402 Check(*mask);
403 }
404 }
405 void Check(const parser::ScalarLogicalExpr &x) {
406 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
407 }
408 void Check(const parser::ScalarIntExpr &x) {
409 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
410 }
411 void Check(const parser::ScalarExpr &x) {
412 Check(DEREF(parser::Unwrap<parser::Expr>(x)));
413 }
414 void Check(const parser::Expr &expr) {
415 if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
416 context_.Say(expr.source, std::move(*msg));
417 }
418 }
419
420 SemanticsContext &context_;
421};
422
423void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) {
424 DeviceContextChecker<false>{context_}.CheckSubprogram(
425 std::get<parser::Name>(
426 std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t),
427 std::get<parser::ExecutionPart>(x.t).v);
428}
429
430void CUDAChecker::Enter(const parser::FunctionSubprogram &x) {
431 DeviceContextChecker<false>{context_}.CheckSubprogram(
432 std::get<parser::Name>(
433 std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t),
434 std::get<parser::ExecutionPart>(x.t).v);
435}
436
437void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) {
438 DeviceContextChecker<false>{context_}.CheckSubprogram(
439 std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v,
440 std::get<parser::ExecutionPart>(x.t).v);
441}
442
443// !$CUF KERNEL DO semantic checks
444
445static int DoConstructTightNesting(
446 const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) {
447 if (!doConstruct || !doConstruct->IsDoNormal()) {
448 return 0;
449 }
450 innerBlock = &std::get<parser::Block>(doConstruct->t);
451 if (innerBlock->size() == 1) {
452 if (const auto *execConstruct{
453 std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) {
454 if (const auto *next{
455 std::get_if<common::Indirection<parser::DoConstruct>>(
456 &execConstruct->u)}) {
457 return 1 + DoConstructTightNesting(&next->value(), innerBlock);
458 }
459 }
460 }
461 return 1;
462}
463
464void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) {
465 auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source};
466 const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)};
467 std::int64_t depth{1};
468 if (auto expr{AnalyzeExpr(context_,
469 std::get<std::optional<parser::ScalarIntConstantExpr>>(
470 directive.t))}) {
471 depth = evaluate::ToInt64(expr).value_or(0);
472 if (depth <= 0) {
473 context_.Say(source,
474 "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US,
475 std::intmax_t{depth});
476 depth = 1;
477 }
478 }
479 const parser::DoConstruct *doConstruct{common::GetPtrFromOptional(
480 std::get<std::optional<parser::DoConstruct>>(x.t))};
481 const parser::Block *innerBlock{nullptr};
482 if (DoConstructTightNesting(doConstruct, innerBlock) < depth) {
483 context_.Say(source,
484 "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US,
485 std::intmax_t{depth});
486 }
487 if (innerBlock) {
488 DeviceContextChecker<true>{context_}.Check(*innerBlock);
489 }
490}
491
492void CUDAChecker::Enter(const parser::AssignmentStmt &x) {
493 auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()};
494 const auto &scope{context_.FindScope(lhsLoc)};
495 const Scope &progUnit{GetProgramUnitContaining(scope)};
496 if (IsCUDADeviceContext(&progUnit)) {
497 return; // Data transfer with assignment is only perform on host.
498 }
499
500 const evaluate::Assignment *assign{semantics::GetAssignment(x)};
501 if (!assign) {
502 return;
503 }
504
505 int nbLhs{evaluate::GetNbOfCUDASymbols(assign->lhs)};
506 int nbRhs{evaluate::GetNbOfCUDASymbols(assign->rhs)};
507
508 // device to host transfer with more than one device object on the rhs is not
509 // legal.
510 if (nbLhs == 0 && nbRhs > 1) {
511 context_.Say(lhsLoc,
512 "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US);
513 }
514}
515
516} // namespace Fortran::semantics
517

source code of flang/lib/Semantics/check-cuda.cpp