1 | //===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "check-cuda.h" |
10 | #include "flang/Common/template.h" |
11 | #include "flang/Evaluate/fold.h" |
12 | #include "flang/Evaluate/tools.h" |
13 | #include "flang/Evaluate/traverse.h" |
14 | #include "flang/Parser/parse-tree-visitor.h" |
15 | #include "flang/Parser/parse-tree.h" |
16 | #include "flang/Parser/tools.h" |
17 | #include "flang/Semantics/expression.h" |
18 | #include "flang/Semantics/symbol.h" |
19 | #include "flang/Semantics/tools.h" |
20 | |
21 | // Once labeled DO constructs have been canonicalized and their parse subtrees |
22 | // transformed into parser::DoConstructs, scan the parser::Blocks of the program |
23 | // and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the |
24 | // CUFKernelDoConstruct doesn't already have an embedded DoConstruct. Also |
25 | // emit errors about improper or missing DoConstructs. |
26 | |
27 | namespace Fortran::parser { |
28 | struct Mutator { |
29 | template <typename A> bool Pre(A &) { return true; } |
30 | template <typename A> void Post(A &) {} |
31 | bool Pre(Block &); |
32 | }; |
33 | |
34 | bool Mutator::Pre(Block &block) { |
35 | for (auto iter{block.begin()}; iter != block.end(); ++iter) { |
36 | if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) { |
37 | auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)}; |
38 | if (!nested) { |
39 | if (auto next{iter}; ++next != block.end()) { |
40 | if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) { |
41 | nested = std::move(*doConstruct); |
42 | block.erase(next); |
43 | } |
44 | } |
45 | } |
46 | } else { |
47 | Walk(*iter, *this); |
48 | } |
49 | } |
50 | return false; |
51 | } |
52 | } // namespace Fortran::parser |
53 | |
54 | namespace Fortran::semantics { |
55 | |
56 | bool CanonicalizeCUDA(parser::Program &program) { |
57 | parser::Mutator mutator; |
58 | parser::Walk(program, mutator); |
59 | return true; |
60 | } |
61 | |
62 | using MaybeMsg = std::optional<parser::MessageFormattedText>; |
63 | |
64 | // Traverses an evaluate::Expr<> in search of unsupported operations |
65 | // on the device. |
66 | |
67 | struct DeviceExprChecker |
68 | : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> { |
69 | using Result = MaybeMsg; |
70 | using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>; |
71 | DeviceExprChecker() : Base(*this) {} |
72 | using Base::operator(); |
73 | Result operator()(const evaluate::ProcedureDesignator &x) const { |
74 | if (const Symbol * sym{x.GetInterfaceSymbol()}) { |
75 | const auto *subp{ |
76 | sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()}; |
77 | if (subp) { |
78 | if (auto attrs{subp->cudaSubprogramAttrs()}) { |
79 | if (*attrs == common::CUDASubprogramAttrs::HostDevice || |
80 | *attrs == common::CUDASubprogramAttrs::Device) { |
81 | return {}; |
82 | } |
83 | } |
84 | } |
85 | } else if (x.GetSpecificIntrinsic()) { |
86 | // TODO(CUDA): Check for unsupported intrinsics here |
87 | return {}; |
88 | } |
89 | return parser::MessageFormattedText( |
90 | "'%s' may not be called in device code"_err_en_US , x.GetName()); |
91 | } |
92 | }; |
93 | |
94 | template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) { |
95 | if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) { |
96 | return DeviceExprChecker{}(expr->typedExpr); |
97 | } |
98 | return {}; |
99 | } |
100 | |
101 | template <typename A> |
102 | static void CheckUnwrappedExpr( |
103 | SemanticsContext &context, SourceName at, const A &x) { |
104 | if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) { |
105 | if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) { |
106 | context.Say(at, std::move(*msg)); |
107 | } |
108 | } |
109 | } |
110 | |
111 | template <bool CUF_KERNEL> struct ActionStmtChecker { |
112 | template <typename A> static MaybeMsg WhyNotOk(const A &x) { |
113 | if constexpr (ConstraintTrait<A>) { |
114 | return WhyNotOk(x.thing); |
115 | } else if constexpr (WrapperTrait<A>) { |
116 | return WhyNotOk(x.v); |
117 | } else if constexpr (UnionTrait<A>) { |
118 | return WhyNotOk(x.u); |
119 | } else if constexpr (TupleTrait<A>) { |
120 | return WhyNotOk(x.t); |
121 | } else { |
122 | return parser::MessageFormattedText{ |
123 | "Statement may not appear in device code"_err_en_US }; |
124 | } |
125 | } |
126 | template <typename A> |
127 | static MaybeMsg WhyNotOk(const common::Indirection<A> &x) { |
128 | return WhyNotOk(x.value()); |
129 | } |
130 | template <typename... As> |
131 | static MaybeMsg WhyNotOk(const std::variant<As...> &x) { |
132 | return common::visit([](const auto &x) { return WhyNotOk(x); }, x); |
133 | } |
134 | template <std::size_t J = 0, typename... As> |
135 | static MaybeMsg WhyNotOk(const std::tuple<As...> &x) { |
136 | if constexpr (J == sizeof...(As)) { |
137 | return {}; |
138 | } else if (auto msg{WhyNotOk(std::get<J>(x))}) { |
139 | return msg; |
140 | } else { |
141 | return WhyNotOk<(J + 1)>(x); |
142 | } |
143 | } |
144 | template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) { |
145 | for (const auto &y : x) { |
146 | if (MaybeMsg result{WhyNotOk(y)}) { |
147 | return result; |
148 | } |
149 | } |
150 | return {}; |
151 | } |
152 | template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) { |
153 | if (x) { |
154 | return WhyNotOk(*x); |
155 | } else { |
156 | return {}; |
157 | } |
158 | } |
159 | template <typename A> |
160 | static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) { |
161 | return WhyNotOk(x.statement); |
162 | } |
163 | template <typename A> |
164 | static MaybeMsg WhyNotOk(const parser::Statement<A> &x) { |
165 | return WhyNotOk(x.statement); |
166 | } |
167 | static MaybeMsg WhyNotOk(const parser::AllocateStmt &) { |
168 | return {}; // AllocateObjects are checked elsewhere |
169 | } |
170 | static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) { |
171 | return parser::MessageFormattedText( |
172 | "A coarray may not be allocated on the device"_err_en_US ); |
173 | } |
174 | static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) { |
175 | return {}; // AllocateObjects are checked elsewhere |
176 | } |
177 | static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) { |
178 | return DeviceExprChecker{}(x.typedAssignment); |
179 | } |
180 | static MaybeMsg WhyNotOk(const parser::CallStmt &x) { |
181 | return DeviceExprChecker{}(x.typedCall); |
182 | } |
183 | static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; } |
184 | static MaybeMsg WhyNotOk(const parser::IfStmt &x) { |
185 | if (auto result{ |
186 | CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) { |
187 | return result; |
188 | } |
189 | return WhyNotOk( |
190 | std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t) |
191 | .statement); |
192 | } |
193 | static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) { |
194 | for (const auto &y : x.v) { |
195 | if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) { |
196 | return result; |
197 | } |
198 | } |
199 | return {}; |
200 | } |
201 | static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) { |
202 | return DeviceExprChecker{}(x.typedAssignment); |
203 | } |
204 | }; |
205 | |
206 | template <bool IsCUFKernelDo> class DeviceContextChecker { |
207 | public: |
208 | explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {} |
209 | void CheckSubprogram(const parser::Name &name, const parser::Block &body) { |
210 | if (name.symbol) { |
211 | const auto *subp{ |
212 | name.symbol->GetUltimate().detailsIf<SubprogramDetails>()}; |
213 | if (subp && subp->moduleInterface()) { |
214 | subp = subp->moduleInterface() |
215 | ->GetUltimate() |
216 | .detailsIf<SubprogramDetails>(); |
217 | } |
218 | if (subp && |
219 | subp->cudaSubprogramAttrs().value_or( |
220 | common::CUDASubprogramAttrs::Host) != |
221 | common::CUDASubprogramAttrs::Host) { |
222 | Check(body); |
223 | } |
224 | } |
225 | } |
226 | void Check(const parser::Block &block) { |
227 | for (const auto &epc : block) { |
228 | Check(epc); |
229 | } |
230 | } |
231 | |
232 | private: |
233 | void Check(const parser::ExecutionPartConstruct &epc) { |
234 | common::visit( |
235 | common::visitors{ |
236 | [&](const parser::ExecutableConstruct &x) { Check(x); }, |
237 | [&](const parser::Statement<common::Indirection<parser::EntryStmt>> |
238 | &x) { |
239 | context_.Say(x.source, |
240 | "Device code may not contain an ENTRY statement"_err_en_US ); |
241 | }, |
242 | [](const parser::Statement<common::Indirection<parser::FormatStmt>> |
243 | &) {}, |
244 | [](const parser::Statement<common::Indirection<parser::DataStmt>> |
245 | &) {}, |
246 | [](const parser::Statement< |
247 | common::Indirection<parser::NamelistStmt>> &) {}, |
248 | [](const parser::ErrorRecovery &) {}, |
249 | }, |
250 | epc.u); |
251 | } |
252 | void Check(const parser::ExecutableConstruct &ec) { |
253 | common::visit( |
254 | common::visitors{ |
255 | [&](const parser::Statement<parser::ActionStmt> &stmt) { |
256 | Check(stmt.statement, stmt.source); |
257 | }, |
258 | [&](const common::Indirection<parser::DoConstruct> &x) { |
259 | if (const std::optional<parser::LoopControl> &control{ |
260 | x.value().GetLoopControl()}) { |
261 | common::visit([&](const auto &y) { Check(y); }, control->u); |
262 | } |
263 | Check(std::get<parser::Block>(x.value().t)); |
264 | }, |
265 | [&](const common::Indirection<parser::BlockConstruct> &x) { |
266 | Check(std::get<parser::Block>(x.value().t)); |
267 | }, |
268 | [&](const common::Indirection<parser::IfConstruct> &x) { |
269 | Check(x.value()); |
270 | }, |
271 | [&](const auto &x) { |
272 | if (auto source{parser::GetSource(x)}) { |
273 | context_.Say(*source, |
274 | "Statement may not appear in device code"_err_en_US ); |
275 | } |
276 | }, |
277 | }, |
278 | ec.u); |
279 | } |
280 | template <typename SEEK, typename A> |
281 | static const SEEK *GetIOControl(const A &stmt) { |
282 | for (const auto &spec : stmt.controls) { |
283 | if (const auto *result{std::get_if<SEEK>(&spec.u)}) { |
284 | return result; |
285 | } |
286 | } |
287 | return nullptr; |
288 | } |
289 | template <typename A> static bool IsInternalIO(const A &stmt) { |
290 | if (stmt.iounit.has_value()) { |
291 | return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u); |
292 | } |
293 | if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) { |
294 | return std::holds_alternative<Fortran::parser::Variable>(unit->u); |
295 | } |
296 | return false; |
297 | } |
298 | void WarnOnIoStmt(const parser::CharBlock &source) { |
299 | context_.Say( |
300 | source, "I/O statement might not be supported on device"_warn_en_US ); |
301 | } |
302 | template <typename A> |
303 | void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) { |
304 | if (!IsInternalIO(stmt)) { |
305 | WarnOnIoStmt(source); |
306 | } |
307 | } |
308 | void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) { |
309 | common::visit( |
310 | common::visitors{ |
311 | [&](const common::Indirection<parser::PrintStmt> &) {}, |
312 | [&](const common::Indirection<parser::WriteStmt> &x) { |
313 | if (x.value().format) { // Formatted write to '*' or '6' |
314 | if (std::holds_alternative<Fortran::parser::Star>( |
315 | x.value().format->u)) { |
316 | if (x.value().iounit) { |
317 | if (std::holds_alternative<Fortran::parser::Star>( |
318 | x.value().iounit->u)) { |
319 | return; |
320 | } |
321 | } |
322 | } |
323 | } |
324 | WarnIfNotInternal(x.value(), source); |
325 | }, |
326 | [&](const common::Indirection<parser::CloseStmt> &x) { |
327 | WarnOnIoStmt(source); |
328 | }, |
329 | [&](const common::Indirection<parser::EndfileStmt> &x) { |
330 | WarnOnIoStmt(source); |
331 | }, |
332 | [&](const common::Indirection<parser::OpenStmt> &x) { |
333 | WarnOnIoStmt(source); |
334 | }, |
335 | [&](const common::Indirection<parser::ReadStmt> &x) { |
336 | WarnIfNotInternal(x.value(), source); |
337 | }, |
338 | [&](const common::Indirection<parser::InquireStmt> &x) { |
339 | WarnOnIoStmt(source); |
340 | }, |
341 | [&](const common::Indirection<parser::RewindStmt> &x) { |
342 | WarnOnIoStmt(source); |
343 | }, |
344 | [&](const common::Indirection<parser::BackspaceStmt> &x) { |
345 | WarnOnIoStmt(source); |
346 | }, |
347 | [&](const common::Indirection<parser::IfStmt> &x) { |
348 | Check(x.value()); |
349 | }, |
350 | [&](const auto &x) { |
351 | if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) { |
352 | context_.Say(source, std::move(*msg)); |
353 | } |
354 | }, |
355 | }, |
356 | stmt.u); |
357 | } |
358 | void Check(const parser::IfConstruct &ic) { |
359 | const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)}; |
360 | CheckUnwrappedExpr(context_, ifS.source, |
361 | std::get<parser::ScalarLogicalExpr>(ifS.statement.t)); |
362 | Check(std::get<parser::Block>(ic.t)); |
363 | for (const auto &eib : |
364 | std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) { |
365 | const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)}; |
366 | CheckUnwrappedExpr(context_, eIfS.source, |
367 | std::get<parser::ScalarLogicalExpr>(eIfS.statement.t)); |
368 | Check(std::get<parser::Block>(eib.t)); |
369 | } |
370 | if (const auto &eb{ |
371 | std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) { |
372 | Check(std::get<parser::Block>(eb->t)); |
373 | } |
374 | } |
375 | void Check(const parser::IfStmt &is) { |
376 | const auto &uS{ |
377 | std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)}; |
378 | CheckUnwrappedExpr( |
379 | context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t)); |
380 | Check(uS.statement, uS.source); |
381 | } |
382 | void Check(const parser::LoopControl::Bounds &bounds) { |
383 | Check(bounds.lower); |
384 | Check(bounds.upper); |
385 | if (bounds.step) { |
386 | Check(*bounds.step); |
387 | } |
388 | } |
389 | void Check(const parser::LoopControl::Concurrent &x) { |
390 | const auto &{std::get<parser::ConcurrentHeader>(x.t)}; |
391 | for (const auto &cc : |
392 | std::get<std::list<parser::ConcurrentControl>>(header.t)) { |
393 | Check(std::get<1>(cc.t)); |
394 | Check(std::get<2>(cc.t)); |
395 | if (const auto &step{ |
396 | std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) { |
397 | Check(*step); |
398 | } |
399 | } |
400 | if (const auto &mask{ |
401 | std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) { |
402 | Check(*mask); |
403 | } |
404 | } |
405 | void Check(const parser::ScalarLogicalExpr &x) { |
406 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
407 | } |
408 | void Check(const parser::ScalarIntExpr &x) { |
409 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
410 | } |
411 | void Check(const parser::ScalarExpr &x) { |
412 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
413 | } |
414 | void Check(const parser::Expr &expr) { |
415 | if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) { |
416 | context_.Say(expr.source, std::move(*msg)); |
417 | } |
418 | } |
419 | |
420 | SemanticsContext &context_; |
421 | }; |
422 | |
423 | void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) { |
424 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
425 | std::get<parser::Name>( |
426 | std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t), |
427 | std::get<parser::ExecutionPart>(x.t).v); |
428 | } |
429 | |
430 | void CUDAChecker::Enter(const parser::FunctionSubprogram &x) { |
431 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
432 | std::get<parser::Name>( |
433 | std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t), |
434 | std::get<parser::ExecutionPart>(x.t).v); |
435 | } |
436 | |
437 | void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) { |
438 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
439 | std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v, |
440 | std::get<parser::ExecutionPart>(x.t).v); |
441 | } |
442 | |
443 | // !$CUF KERNEL DO semantic checks |
444 | |
445 | static int DoConstructTightNesting( |
446 | const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) { |
447 | if (!doConstruct || !doConstruct->IsDoNormal()) { |
448 | return 0; |
449 | } |
450 | innerBlock = &std::get<parser::Block>(doConstruct->t); |
451 | if (innerBlock->size() == 1) { |
452 | if (const auto *execConstruct{ |
453 | std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) { |
454 | if (const auto *next{ |
455 | std::get_if<common::Indirection<parser::DoConstruct>>( |
456 | &execConstruct->u)}) { |
457 | return 1 + DoConstructTightNesting(&next->value(), innerBlock); |
458 | } |
459 | } |
460 | } |
461 | return 1; |
462 | } |
463 | |
464 | void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) { |
465 | auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source}; |
466 | const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)}; |
467 | std::int64_t depth{1}; |
468 | if (auto expr{AnalyzeExpr(context_, |
469 | std::get<std::optional<parser::ScalarIntConstantExpr>>( |
470 | directive.t))}) { |
471 | depth = evaluate::ToInt64(expr).value_or(0); |
472 | if (depth <= 0) { |
473 | context_.Say(source, |
474 | "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US , |
475 | std::intmax_t{depth}); |
476 | depth = 1; |
477 | } |
478 | } |
479 | const parser::DoConstruct *doConstruct{common::GetPtrFromOptional( |
480 | std::get<std::optional<parser::DoConstruct>>(x.t))}; |
481 | const parser::Block *innerBlock{nullptr}; |
482 | if (DoConstructTightNesting(doConstruct, innerBlock) < depth) { |
483 | context_.Say(source, |
484 | "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US , |
485 | std::intmax_t{depth}); |
486 | } |
487 | if (innerBlock) { |
488 | DeviceContextChecker<true>{context_}.Check(*innerBlock); |
489 | } |
490 | } |
491 | |
492 | void CUDAChecker::Enter(const parser::AssignmentStmt &x) { |
493 | auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()}; |
494 | const auto &scope{context_.FindScope(lhsLoc)}; |
495 | const Scope &progUnit{GetProgramUnitContaining(scope)}; |
496 | if (IsCUDADeviceContext(&progUnit)) { |
497 | return; // Data transfer with assignment is only perform on host. |
498 | } |
499 | |
500 | const evaluate::Assignment *assign{semantics::GetAssignment(x)}; |
501 | if (!assign) { |
502 | return; |
503 | } |
504 | |
505 | int nbLhs{evaluate::GetNbOfCUDASymbols(assign->lhs)}; |
506 | int nbRhs{evaluate::GetNbOfCUDASymbols(assign->rhs)}; |
507 | |
508 | // device to host transfer with more than one device object on the rhs is not |
509 | // legal. |
510 | if (nbLhs == 0 && nbRhs > 1) { |
511 | context_.Say(lhsLoc, |
512 | "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US ); |
513 | } |
514 | } |
515 | |
516 | } // namespace Fortran::semantics |
517 | |