1 | //===-- lib/Semantics/check-cuda.cpp ----------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "check-cuda.h" |
10 | #include "flang/Common/template.h" |
11 | #include "flang/Evaluate/fold.h" |
12 | #include "flang/Evaluate/tools.h" |
13 | #include "flang/Evaluate/traverse.h" |
14 | #include "flang/Parser/parse-tree-visitor.h" |
15 | #include "flang/Parser/parse-tree.h" |
16 | #include "flang/Parser/tools.h" |
17 | #include "flang/Semantics/expression.h" |
18 | #include "flang/Semantics/symbol.h" |
19 | #include "flang/Semantics/tools.h" |
20 | #include "llvm/ADT/StringSet.h" |
21 | |
22 | // Once labeled DO constructs have been canonicalized and their parse subtrees |
23 | // transformed into parser::DoConstructs, scan the parser::Blocks of the program |
24 | // and merge adjacent CUFKernelDoConstructs and DoConstructs whenever the |
25 | // CUFKernelDoConstruct doesn't already have an embedded DoConstruct. Also |
26 | // emit errors about improper or missing DoConstructs. |
27 | |
28 | namespace Fortran::parser { |
29 | struct Mutator { |
30 | template <typename A> bool Pre(A &) { return true; } |
31 | template <typename A> void Post(A &) {} |
32 | bool Pre(Block &); |
33 | }; |
34 | |
35 | bool Mutator::Pre(Block &block) { |
36 | for (auto iter{block.begin()}; iter != block.end(); ++iter) { |
37 | if (auto *kernel{Unwrap<CUFKernelDoConstruct>(*iter)}) { |
38 | auto &nested{std::get<std::optional<DoConstruct>>(kernel->t)}; |
39 | if (!nested) { |
40 | if (auto next{iter}; ++next != block.end()) { |
41 | if (auto *doConstruct{Unwrap<DoConstruct>(*next)}) { |
42 | nested = std::move(*doConstruct); |
43 | block.erase(next); |
44 | } |
45 | } |
46 | } |
47 | } else { |
48 | Walk(*iter, *this); |
49 | } |
50 | } |
51 | return false; |
52 | } |
53 | } // namespace Fortran::parser |
54 | |
55 | namespace Fortran::semantics { |
56 | |
57 | bool CanonicalizeCUDA(parser::Program &program) { |
58 | parser::Mutator mutator; |
59 | parser::Walk(program, mutator); |
60 | return true; |
61 | } |
62 | |
63 | using MaybeMsg = std::optional<parser::MessageFormattedText>; |
64 | |
65 | static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj" , |
66 | "match_all_syncjx" , "match_all_syncjf" , "match_all_syncjd" , |
67 | "match_any_syncjj" , "match_any_syncjx" , "match_any_syncjf" , |
68 | "match_any_syncjd" }; |
69 | |
70 | // Traverses an evaluate::Expr<> in search of unsupported operations |
71 | // on the device. |
72 | |
73 | struct DeviceExprChecker |
74 | : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> { |
75 | using Result = MaybeMsg; |
76 | using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>; |
77 | explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {} |
78 | using Base::operator(); |
79 | Result operator()(const evaluate::ProcedureDesignator &x) const { |
80 | if (const Symbol * sym{x.GetInterfaceSymbol()}) { |
81 | const auto *subp{ |
82 | sym->GetUltimate().detailsIf<semantics::SubprogramDetails>()}; |
83 | if (subp) { |
84 | if (auto attrs{subp->cudaSubprogramAttrs()}) { |
85 | if (*attrs == common::CUDASubprogramAttrs::HostDevice || |
86 | *attrs == common::CUDASubprogramAttrs::Device) { |
87 | if (warpFunctions_.contains(sym->name().ToString()) && |
88 | !context_.languageFeatures().IsEnabled( |
89 | Fortran::common::LanguageFeature::CudaWarpMatchFunction)) { |
90 | return parser::MessageFormattedText( |
91 | "warp match function disabled"_err_en_US ); |
92 | } |
93 | return {}; |
94 | } |
95 | } |
96 | } |
97 | |
98 | const Symbol &ultimate{sym->GetUltimate()}; |
99 | const Scope &scope{ultimate.owner()}; |
100 | const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr}; |
101 | // Allow ieee_arithmetic module functions to be called on the device. |
102 | // TODO: Check for unsupported ieee_arithmetic on the device. |
103 | if (mod && mod->name() == "ieee_arithmetic" ) { |
104 | return {}; |
105 | } |
106 | } else if (x.GetSpecificIntrinsic()) { |
107 | // TODO(CUDA): Check for unsupported intrinsics here |
108 | return {}; |
109 | } |
110 | |
111 | return parser::MessageFormattedText( |
112 | "'%s' may not be called in device code"_err_en_US , x.GetName()); |
113 | } |
114 | |
115 | SemanticsContext &context_; |
116 | }; |
117 | |
118 | struct FindHostArray |
119 | : public evaluate::AnyTraverse<FindHostArray, const Symbol *> { |
120 | using Result = const Symbol *; |
121 | using Base = evaluate::AnyTraverse<FindHostArray, Result>; |
122 | FindHostArray() : Base(*this) {} |
123 | using Base::operator(); |
124 | Result operator()(const evaluate::Component &x) const { |
125 | const Symbol &symbol{x.GetLastSymbol()}; |
126 | if (IsAllocatableOrPointer(symbol)) { |
127 | if (Result hostArray{(*this)(symbol)}) { |
128 | return hostArray; |
129 | } |
130 | } |
131 | return (*this)(x.base()); |
132 | } |
133 | Result operator()(const Symbol &symbol) const { |
134 | if (const auto *details{ |
135 | symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) { |
136 | if (details->IsArray() && |
137 | !symbol.attrs().test(Fortran::semantics::Attr::PARAMETER) && |
138 | (!details->cudaDataAttr() || |
139 | (details->cudaDataAttr() && |
140 | *details->cudaDataAttr() != common::CUDADataAttr::Device && |
141 | *details->cudaDataAttr() != common::CUDADataAttr::Constant && |
142 | *details->cudaDataAttr() != common::CUDADataAttr::Managed && |
143 | *details->cudaDataAttr() != common::CUDADataAttr::Shared && |
144 | *details->cudaDataAttr() != common::CUDADataAttr::Unified))) { |
145 | return &symbol; |
146 | } |
147 | } |
148 | return nullptr; |
149 | } |
150 | }; |
151 | |
152 | template <typename A> |
153 | static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) { |
154 | if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) { |
155 | return DeviceExprChecker{context}(expr->typedExpr); |
156 | } |
157 | return {}; |
158 | } |
159 | |
160 | template <typename A> |
161 | static void CheckUnwrappedExpr( |
162 | SemanticsContext &context, SourceName at, const A &x) { |
163 | if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) { |
164 | if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) { |
165 | context.Say(at, std::move(*msg)); |
166 | } |
167 | } |
168 | } |
169 | |
// Decides whether an action statement may appear in device code. Every
// WhyNotOk overload returns a diagnostic explaining why its argument is not
// acceptable, or an empty MaybeMsg when it is. Parse-tree wrapper, union,
// tuple, list, optional, Indirection and Statement nodes are unwrapped
// recursively. A node type matching no more specific overload reaches the
// generic template's final branch and is rejected.
// NOTE(review): the CUF_KERNEL template parameter is not referenced in this
// struct's body.
template <bool CUF_KERNEL> struct ActionStmtChecker {
  // Generic case: descend through the parse-tree trait shapes. A node that
  // has none of them is a statement kind not allowed on the device.
  template <typename A>
  static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
    if constexpr (ConstraintTrait<A>) {
      return WhyNotOk(context, x.thing);
    } else if constexpr (WrapperTrait<A>) {
      return WhyNotOk(context, x.v);
    } else if constexpr (UnionTrait<A>) {
      return WhyNotOk(context, x.u);
    } else if constexpr (TupleTrait<A>) {
      return WhyNotOk(context, x.t);
    } else {
      return parser::MessageFormattedText{
          "Statement may not appear in device code"_err_en_US};
    }
  }
  // Looks through an Indirection to the node it owns.
  template <typename A>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const common::Indirection<A> &x) {
    return WhyNotOk(context, x.value());
  }
  // Checks whichever alternative the variant currently holds.
  template <typename... As>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const std::variant<As...> &x) {
    return common::visit(
        [&context](const auto &x) { return WhyNotOk(context, x); }, x);
  }
  // Checks tuple elements J, J+1, ... in order and returns the first
  // diagnostic; recursion ends when J reaches the tuple size.
  template <std::size_t J = 0, typename... As>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const std::tuple<As...> &x) {
    if constexpr (J == sizeof...(As)) {
      return {};
    } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
      return msg;
    } else {
      return WhyNotOk<(J + 1)>(context, x);
    }
  }
  // Returns the first diagnostic produced by any list element.
  template <typename A>
  static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
    for (const auto &y : x) {
      if (MaybeMsg result{WhyNotOk(context, y)}) {
        return result;
      }
    }
    return {};
  }
  // An absent optional is always acceptable.
  template <typename A>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const std::optional<A> &x) {
    if (x) {
      return WhyNotOk(context, *x);
    } else {
      return {};
    }
  }
  template <typename A>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
    return WhyNotOk(context, x.statement);
  }
  template <typename A>
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::Statement<A> &x) {
    return WhyNotOk(context, x.statement);
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::AllocateStmt &) {
    return {}; // AllocateObjects are checked elsewhere
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::AllocateCoarraySpec &) {
    return parser::MessageFormattedText(
        "A coarray may not be allocated on the device"_err_en_US);
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::DeallocateStmt &) {
    return {}; // AllocateObjects are checked elsewhere
  }
  // Assignments and CALLs are acceptable when DeviceExprChecker finds no
  // procedure reference that is disallowed on the device.
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::AssignmentStmt &x) {
    return DeviceExprChecker{context}(x.typedAssignment);
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::CallStmt &x) {
    return DeviceExprChecker{context}(x.typedCall);
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::ContinueStmt &) {
    return {};
  }
  // Logical IF: check the condition expression, then the guarded statement.
  static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
    if (auto result{CheckUnwrappedExpr(
            context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
      return result;
    }
    return WhyNotOk(context,
        std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
            .statement);
  }
  // NULLIFY: each pointer object's typed expression is checked in turn.
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::NullifyStmt &x) {
    for (const auto &y : x.v) {
      if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
        return result;
      }
    }
    return {};
  }
  static MaybeMsg WhyNotOk(
      SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
    return DeviceExprChecker{context}(x.typedAssignment);
  }
};
284 | |
285 | template <bool IsCUFKernelDo> class DeviceContextChecker { |
286 | public: |
287 | explicit DeviceContextChecker(SemanticsContext &c) : context_{c} {} |
288 | void CheckSubprogram(const parser::Name &name, const parser::Block &body) { |
289 | if (name.symbol) { |
290 | const auto *subp{ |
291 | name.symbol->GetUltimate().detailsIf<SubprogramDetails>()}; |
292 | if (subp && subp->moduleInterface()) { |
293 | subp = subp->moduleInterface() |
294 | ->GetUltimate() |
295 | .detailsIf<SubprogramDetails>(); |
296 | } |
297 | if (subp && |
298 | subp->cudaSubprogramAttrs().value_or( |
299 | common::CUDASubprogramAttrs::Host) != |
300 | common::CUDASubprogramAttrs::Host) { |
301 | isHostDevice = subp->cudaSubprogramAttrs() && |
302 | subp->cudaSubprogramAttrs() == |
303 | common::CUDASubprogramAttrs::HostDevice; |
304 | Check(body); |
305 | } |
306 | } |
307 | } |
308 | void Check(const parser::Block &block) { |
309 | for (const auto &epc : block) { |
310 | Check(epc); |
311 | } |
312 | } |
313 | |
314 | private: |
315 | void Check(const parser::ExecutionPartConstruct &epc) { |
316 | common::visit( |
317 | common::visitors{ |
318 | [&](const parser::ExecutableConstruct &x) { Check(x); }, |
319 | [&](const parser::Statement<common::Indirection<parser::EntryStmt>> |
320 | &x) { |
321 | context_.Say(x.source, |
322 | "Device code may not contain an ENTRY statement"_err_en_US ); |
323 | }, |
324 | [](const parser::Statement<common::Indirection<parser::FormatStmt>> |
325 | &) {}, |
326 | [](const parser::Statement<common::Indirection<parser::DataStmt>> |
327 | &) {}, |
328 | [](const parser::Statement< |
329 | common::Indirection<parser::NamelistStmt>> &) {}, |
330 | [](const parser::ErrorRecovery &) {}, |
331 | }, |
332 | epc.u); |
333 | } |
334 | void Check(const parser::ExecutableConstruct &ec) { |
335 | common::visit( |
336 | common::visitors{ |
337 | [&](const parser::Statement<parser::ActionStmt> &stmt) { |
338 | Check(stmt.statement, stmt.source); |
339 | }, |
340 | [&](const common::Indirection<parser::DoConstruct> &x) { |
341 | if (const std::optional<parser::LoopControl> &control{ |
342 | x.value().GetLoopControl()}) { |
343 | common::visit([&](const auto &y) { Check(y); }, control->u); |
344 | } |
345 | Check(std::get<parser::Block>(x.value().t)); |
346 | }, |
347 | [&](const common::Indirection<parser::BlockConstruct> &x) { |
348 | Check(std::get<parser::Block>(x.value().t)); |
349 | }, |
350 | [&](const common::Indirection<parser::IfConstruct> &x) { |
351 | Check(x.value()); |
352 | }, |
353 | [&](const common::Indirection<parser::CaseConstruct> &x) { |
354 | const auto &caseList{ |
355 | std::get<std::list<parser::CaseConstruct::Case>>( |
356 | x.value().t)}; |
357 | for (const parser::CaseConstruct::Case &c : caseList) { |
358 | Check(std::get<parser::Block>(c.t)); |
359 | } |
360 | }, |
361 | [&](const common::Indirection<parser::CompilerDirective> &x) { |
362 | // TODO(CUDA): Check for unsupported compiler directive here. |
363 | }, |
364 | [&](const auto &x) { |
365 | if (auto source{parser::GetSource(x)}) { |
366 | context_.Say(*source, |
367 | "Statement may not appear in device code"_err_en_US ); |
368 | } |
369 | }, |
370 | }, |
371 | ec.u); |
372 | } |
373 | template <typename SEEK, typename A> |
374 | static const SEEK *GetIOControl(const A &stmt) { |
375 | for (const auto &spec : stmt.controls) { |
376 | if (const auto *result{std::get_if<SEEK>(&spec.u)}) { |
377 | return result; |
378 | } |
379 | } |
380 | return nullptr; |
381 | } |
382 | template <typename A> static bool IsInternalIO(const A &stmt) { |
383 | if (stmt.iounit.has_value()) { |
384 | return std::holds_alternative<Fortran::parser::Variable>(stmt.iounit->u); |
385 | } |
386 | if (auto *unit{GetIOControl<Fortran::parser::IoUnit>(stmt)}) { |
387 | return std::holds_alternative<Fortran::parser::Variable>(unit->u); |
388 | } |
389 | return false; |
390 | } |
391 | void WarnOnIoStmt(const parser::CharBlock &source) { |
392 | context_.Warn(common::UsageWarning::CUDAUsage, source, |
393 | "I/O statement might not be supported on device"_warn_en_US ); |
394 | } |
395 | template <typename A> |
396 | void WarnIfNotInternal(const A &stmt, const parser::CharBlock &source) { |
397 | if (!IsInternalIO(stmt)) { |
398 | WarnOnIoStmt(source); |
399 | } |
400 | } |
401 | template <typename A> |
402 | void ErrorIfHostSymbol(const A &expr, parser::CharBlock source) { |
403 | if (isHostDevice) |
404 | return; |
405 | if (const Symbol * hostArray{FindHostArray{}(expr)}) { |
406 | context_.Say(source, |
407 | "Host array '%s' cannot be present in device context"_err_en_US , |
408 | hostArray->name()); |
409 | } |
410 | } |
411 | void ErrorInCUFKernel(parser::CharBlock source) { |
412 | if (IsCUFKernelDo) { |
413 | context_.Say( |
414 | source, "Statement may not appear in cuf kernel code"_err_en_US ); |
415 | } |
416 | } |
417 | void Check(const parser::ActionStmt &stmt, const parser::CharBlock &source) { |
418 | common::visit( |
419 | common::visitors{ |
420 | [&](const common::Indirection<parser::CycleStmt> &) { |
421 | ErrorInCUFKernel(source); |
422 | }, |
423 | [&](const common::Indirection<parser::ExitStmt> &) { |
424 | ErrorInCUFKernel(source); |
425 | }, |
426 | [&](const common::Indirection<parser::GotoStmt> &) { |
427 | ErrorInCUFKernel(source); |
428 | }, |
429 | [&](const common::Indirection<parser::StopStmt> &) { return; }, |
430 | [&](const common::Indirection<parser::PrintStmt> &) {}, |
431 | [&](const common::Indirection<parser::WriteStmt> &x) { |
432 | if (x.value().format) { // Formatted write to '*' or '6' |
433 | if (std::holds_alternative<Fortran::parser::Star>( |
434 | x.value().format->u)) { |
435 | if (x.value().iounit) { |
436 | if (std::holds_alternative<Fortran::parser::Star>( |
437 | x.value().iounit->u)) { |
438 | return; |
439 | } |
440 | } |
441 | } |
442 | } |
443 | WarnIfNotInternal(x.value(), source); |
444 | }, |
445 | [&](const common::Indirection<parser::CloseStmt> &x) { |
446 | WarnOnIoStmt(source); |
447 | }, |
448 | [&](const common::Indirection<parser::EndfileStmt> &x) { |
449 | WarnOnIoStmt(source); |
450 | }, |
451 | [&](const common::Indirection<parser::OpenStmt> &x) { |
452 | WarnOnIoStmt(source); |
453 | }, |
454 | [&](const common::Indirection<parser::ReadStmt> &x) { |
455 | WarnIfNotInternal(x.value(), source); |
456 | }, |
457 | [&](const common::Indirection<parser::InquireStmt> &x) { |
458 | WarnOnIoStmt(source); |
459 | }, |
460 | [&](const common::Indirection<parser::RewindStmt> &x) { |
461 | WarnOnIoStmt(source); |
462 | }, |
463 | [&](const common::Indirection<parser::BackspaceStmt> &x) { |
464 | WarnOnIoStmt(source); |
465 | }, |
466 | [&](const common::Indirection<parser::IfStmt> &x) { |
467 | Check(x.value()); |
468 | }, |
469 | [&](const common::Indirection<parser::AssignmentStmt> &x) { |
470 | if (const evaluate::Assignment * |
471 | assign{semantics::GetAssignment(x.value())}) { |
472 | ErrorIfHostSymbol(assign->lhs, source); |
473 | ErrorIfHostSymbol(assign->rhs, source); |
474 | } |
475 | if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk( |
476 | context_, x)}) { |
477 | context_.Say(source, std::move(*msg)); |
478 | } |
479 | }, |
480 | [&](const auto &x) { |
481 | if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk( |
482 | context_, x)}) { |
483 | context_.Say(source, std::move(*msg)); |
484 | } |
485 | }, |
486 | }, |
487 | stmt.u); |
488 | } |
489 | void Check(const parser::IfConstruct &ic) { |
490 | const auto &ifS{std::get<parser::Statement<parser::IfThenStmt>>(ic.t)}; |
491 | CheckUnwrappedExpr(context_, ifS.source, |
492 | std::get<parser::ScalarLogicalExpr>(ifS.statement.t)); |
493 | Check(std::get<parser::Block>(ic.t)); |
494 | for (const auto &eib : |
495 | std::get<std::list<parser::IfConstruct::ElseIfBlock>>(ic.t)) { |
496 | const auto &eIfS{std::get<parser::Statement<parser::ElseIfStmt>>(eib.t)}; |
497 | CheckUnwrappedExpr(context_, eIfS.source, |
498 | std::get<parser::ScalarLogicalExpr>(eIfS.statement.t)); |
499 | Check(std::get<parser::Block>(eib.t)); |
500 | } |
501 | if (const auto &eb{ |
502 | std::get<std::optional<parser::IfConstruct::ElseBlock>>(ic.t)}) { |
503 | Check(std::get<parser::Block>(eb->t)); |
504 | } |
505 | } |
506 | void Check(const parser::IfStmt &is) { |
507 | const auto &uS{ |
508 | std::get<parser::UnlabeledStatement<parser::ActionStmt>>(is.t)}; |
509 | CheckUnwrappedExpr( |
510 | context_, uS.source, std::get<parser::ScalarLogicalExpr>(is.t)); |
511 | Check(uS.statement, uS.source); |
512 | } |
513 | void Check(const parser::LoopControl::Bounds &bounds) { |
514 | Check(bounds.lower); |
515 | Check(bounds.upper); |
516 | if (bounds.step) { |
517 | Check(*bounds.step); |
518 | } |
519 | } |
520 | void Check(const parser::LoopControl::Concurrent &x) { |
521 | const auto &{std::get<parser::ConcurrentHeader>(x.t)}; |
522 | for (const auto &cc : |
523 | std::get<std::list<parser::ConcurrentControl>>(header.t)) { |
524 | Check(std::get<1>(cc.t)); |
525 | Check(std::get<2>(cc.t)); |
526 | if (const auto &step{ |
527 | std::get<std::optional<parser::ScalarIntExpr>>(cc.t)}) { |
528 | Check(*step); |
529 | } |
530 | } |
531 | if (const auto &mask{ |
532 | std::get<std::optional<parser::ScalarLogicalExpr>>(header.t)}) { |
533 | Check(*mask); |
534 | } |
535 | } |
536 | void Check(const parser::ScalarLogicalExpr &x) { |
537 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
538 | } |
539 | void Check(const parser::ScalarIntExpr &x) { |
540 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
541 | } |
542 | void Check(const parser::ScalarExpr &x) { |
543 | Check(DEREF(parser::Unwrap<parser::Expr>(x))); |
544 | } |
545 | void Check(const parser::Expr &expr) { |
546 | if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) { |
547 | context_.Say(expr.source, std::move(*msg)); |
548 | } |
549 | } |
550 | |
551 | SemanticsContext &context_; |
552 | bool isHostDevice{false}; |
553 | }; |
554 | |
555 | void CUDAChecker::Enter(const parser::SubroutineSubprogram &x) { |
556 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
557 | std::get<parser::Name>( |
558 | std::get<parser::Statement<parser::SubroutineStmt>>(x.t).statement.t), |
559 | std::get<parser::ExecutionPart>(x.t).v); |
560 | } |
561 | |
562 | void CUDAChecker::Enter(const parser::FunctionSubprogram &x) { |
563 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
564 | std::get<parser::Name>( |
565 | std::get<parser::Statement<parser::FunctionStmt>>(x.t).statement.t), |
566 | std::get<parser::ExecutionPart>(x.t).v); |
567 | } |
568 | |
569 | void CUDAChecker::Enter(const parser::SeparateModuleSubprogram &x) { |
570 | DeviceContextChecker<false>{context_}.CheckSubprogram( |
571 | std::get<parser::Statement<parser::MpSubprogramStmt>>(x.t).statement.v, |
572 | std::get<parser::ExecutionPart>(x.t).v); |
573 | } |
574 | |
575 | // !$CUF KERNEL DO semantic checks |
576 | |
577 | static int DoConstructTightNesting( |
578 | const parser::DoConstruct *doConstruct, const parser::Block *&innerBlock) { |
579 | if (!doConstruct || |
580 | (!doConstruct->IsDoNormal() && !doConstruct->IsDoConcurrent())) { |
581 | return 0; |
582 | } |
583 | innerBlock = &std::get<parser::Block>(doConstruct->t); |
584 | if (doConstruct->IsDoConcurrent()) { |
585 | const auto &loopControl = doConstruct->GetLoopControl(); |
586 | if (loopControl) { |
587 | if (const auto *concurrentControl{ |
588 | std::get_if<parser::LoopControl::Concurrent>(&loopControl->u)}) { |
589 | const auto & = |
590 | std::get<Fortran::parser::ConcurrentHeader>(concurrentControl->t); |
591 | const auto &controls = |
592 | std::get<std::list<Fortran::parser::ConcurrentControl>>( |
593 | concurrentHeader.t); |
594 | return controls.size(); |
595 | } |
596 | } |
597 | return 0; |
598 | } |
599 | if (innerBlock->size() == 1) { |
600 | if (const auto *execConstruct{ |
601 | std::get_if<parser::ExecutableConstruct>(&innerBlock->front().u)}) { |
602 | if (const auto *next{ |
603 | std::get_if<common::Indirection<parser::DoConstruct>>( |
604 | &execConstruct->u)}) { |
605 | return 1 + DoConstructTightNesting(&next->value(), innerBlock); |
606 | } |
607 | } |
608 | } |
609 | return 1; |
610 | } |
611 | |
612 | static void CheckReduce( |
613 | SemanticsContext &context, const parser::CUFReduction &reduce) { |
614 | auto op{std::get<parser::CUFReduction::Operator>(reduce.t).v}; |
615 | for (const auto &var : |
616 | std::get<std::list<parser::Scalar<parser::Variable>>>(reduce.t)) { |
617 | if (const auto &typedExprPtr{var.thing.typedExpr}; |
618 | typedExprPtr && typedExprPtr->v) { |
619 | const auto &expr{*typedExprPtr->v}; |
620 | if (auto type{expr.GetType()}) { |
621 | auto cat{type->category()}; |
622 | bool isOk{false}; |
623 | switch (op) { |
624 | case parser::ReductionOperator::Operator::Plus: |
625 | case parser::ReductionOperator::Operator::Multiply: |
626 | case parser::ReductionOperator::Operator::Max: |
627 | case parser::ReductionOperator::Operator::Min: |
628 | isOk = cat == TypeCategory::Integer || cat == TypeCategory::Real || |
629 | cat == TypeCategory::Complex; |
630 | break; |
631 | case parser::ReductionOperator::Operator::Iand: |
632 | case parser::ReductionOperator::Operator::Ior: |
633 | case parser::ReductionOperator::Operator::Ieor: |
634 | isOk = cat == TypeCategory::Integer; |
635 | break; |
636 | case parser::ReductionOperator::Operator::And: |
637 | case parser::ReductionOperator::Operator::Or: |
638 | case parser::ReductionOperator::Operator::Eqv: |
639 | case parser::ReductionOperator::Operator::Neqv: |
640 | isOk = cat == TypeCategory::Logical; |
641 | break; |
642 | } |
643 | if (!isOk) { |
644 | context.Say(var.thing.GetSource(), |
645 | "!$CUF KERNEL DO REDUCE operation is not acceptable for a variable with type %s"_err_en_US , |
646 | type->AsFortran()); |
647 | } |
648 | } |
649 | } |
650 | } |
651 | } |
652 | |
653 | void CUDAChecker::Enter(const parser::CUFKernelDoConstruct &x) { |
654 | auto source{std::get<parser::CUFKernelDoConstruct::Directive>(x.t).source}; |
655 | const auto &directive{std::get<parser::CUFKernelDoConstruct::Directive>(x.t)}; |
656 | std::int64_t depth{1}; |
657 | if (auto expr{AnalyzeExpr(context_, |
658 | std::get<std::optional<parser::ScalarIntConstantExpr>>( |
659 | directive.t))}) { |
660 | depth = evaluate::ToInt64(expr).value_or(0); |
661 | if (depth <= 0) { |
662 | context_.Say(source, |
663 | "!$CUF KERNEL DO (%jd): loop nesting depth must be positive"_err_en_US , |
664 | std::intmax_t{depth}); |
665 | depth = 1; |
666 | } |
667 | } |
668 | const parser::DoConstruct *doConstruct{common::GetPtrFromOptional( |
669 | std::get<std::optional<parser::DoConstruct>>(x.t))}; |
670 | const parser::Block *innerBlock{nullptr}; |
671 | if (DoConstructTightNesting(doConstruct, innerBlock) < depth) { |
672 | if (doConstruct && doConstruct->IsDoConcurrent()) |
673 | context_.Say(source, |
674 | "!$CUF KERNEL DO (%jd) must be followed by a DO CONCURRENT construct with at least %jd indices"_err_en_US , |
675 | std::intmax_t{depth}, std::intmax_t{depth}); |
676 | else |
677 | context_.Say(source, |
678 | "!$CUF KERNEL DO (%jd) must be followed by a DO construct with tightly nested outer levels of counted DO loops"_err_en_US , |
679 | std::intmax_t{depth}); |
680 | } |
681 | if (innerBlock) { |
682 | DeviceContextChecker<true>{context_}.Check(*innerBlock); |
683 | } |
684 | for (const auto &reduce : |
685 | std::get<std::list<parser::CUFReduction>>(directive.t)) { |
686 | CheckReduce(context_, reduce); |
687 | } |
688 | inCUFKernelDoConstruct_ = true; |
689 | } |
690 | |
691 | void CUDAChecker::Leave(const parser::CUFKernelDoConstruct &) { |
692 | inCUFKernelDoConstruct_ = false; |
693 | } |
694 | |
695 | void CUDAChecker::Enter(const parser::AssignmentStmt &x) { |
696 | auto lhsLoc{std::get<parser::Variable>(x.t).GetSource()}; |
697 | const auto &scope{context_.FindScope(lhsLoc)}; |
698 | const Scope &progUnit{GetProgramUnitContaining(scope)}; |
699 | if (IsCUDADeviceContext(&progUnit) || inCUFKernelDoConstruct_) { |
700 | return; // Data transfer with assignment is only perform on host. |
701 | } |
702 | |
703 | const evaluate::Assignment *assign{semantics::GetAssignment(x)}; |
704 | if (!assign) { |
705 | return; |
706 | } |
707 | |
708 | int nbLhs{evaluate::GetNbOfCUDADeviceSymbols(assign->lhs)}; |
709 | int nbRhs{evaluate::GetNbOfCUDADeviceSymbols(assign->rhs)}; |
710 | |
711 | // device to host transfer with more than one device object on the rhs is not |
712 | // legal. |
713 | if (nbLhs == 0 && nbRhs > 1) { |
714 | context_.Say(lhsLoc, |
715 | "More than one reference to a CUDA object on the right hand side of the assigment"_err_en_US ); |
716 | } |
717 | } |
718 | |
719 | } // namespace Fortran::semantics |
720 | |