1//===-- Atomic.cpp -- Lowering of atomic constructs -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Atomic.h"
10#include "flang/Evaluate/expression.h"
11#include "flang/Evaluate/fold.h"
12#include "flang/Evaluate/tools.h"
13#include "flang/Evaluate/traverse.h"
14#include "flang/Evaluate/type.h"
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/OpenMP/Clauses.h"
17#include "flang/Lower/PFTBuilder.h"
18#include "flang/Lower/StatementContext.h"
19#include "flang/Lower/SymbolMap.h"
20#include "flang/Optimizer/Builder/FIRBuilder.h"
21#include "flang/Optimizer/Builder/Todo.h"
22#include "flang/Parser/parse-tree.h"
23#include "flang/Semantics/semantics.h"
24#include "flang/Semantics/type.h"
25#include "flang/Support/Fortran.h"
26#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
27#include "llvm/ADT/STLExtras.h"
28#include "llvm/Support/CommandLine.h"
29#include "llvm/Support/raw_ostream.h"
30
31#include <optional>
32#include <string>
33#include <type_traits>
34#include <variant>
35#include <vector>
36
37static llvm::cl::opt<bool> DumpAtomicAnalysis("fdebug-dump-atomic-analysis");
38
39using namespace Fortran;
40
41// Don't import the entire Fortran::lower.
42namespace omp {
43using namespace Fortran::lower::omp;
44}
45
46namespace {
47// An example of a type that can be used to get the return value from
48// the visitor:
49// visitor(type_identity<Xyz>) -> result_type
50using SomeArgType = evaluate::Type<common::TypeCategory::Integer, 4>;
51
52struct GetProc
53 : public evaluate::Traverse<GetProc, const evaluate::ProcedureDesignator *,
54 false> {
55 using Result = const evaluate::ProcedureDesignator *;
56 using Base = evaluate::Traverse<GetProc, Result, false>;
57 GetProc() : Base(*this) {}
58
59 using Base::operator();
60
61 static Result Default() { return nullptr; }
62
63 Result operator()(const evaluate::ProcedureDesignator &p) const { return &p; }
64 static Result Combine(Result a, Result b) { return a != nullptr ? a : b; }
65};
66
67struct WithType {
68 WithType(const evaluate::DynamicType &t) : type(t) {
69 assert(type.category() != common::TypeCategory::Derived &&
70 "Type cannot be a derived type");
71 }
72
73 template <typename VisitorTy> //
74 auto visit(VisitorTy &&visitor) const
75 -> std::invoke_result_t<VisitorTy, SomeArgType> {
76 switch (type.category()) {
77 case common::TypeCategory::Integer:
78 switch (type.kind()) {
79 case 1:
80 return visitor(llvm::type_identity<evaluate::Type<Integer, 1>>{});
81 case 2:
82 return visitor(llvm::type_identity<evaluate::Type<Integer, 2>>{});
83 case 4:
84 return visitor(llvm::type_identity<evaluate::Type<Integer, 4>>{});
85 case 8:
86 return visitor(llvm::type_identity<evaluate::Type<Integer, 8>>{});
87 case 16:
88 return visitor(llvm::type_identity<evaluate::Type<Integer, 16>>{});
89 }
90 break;
91 case common::TypeCategory::Unsigned:
92 switch (type.kind()) {
93 case 1:
94 return visitor(llvm::type_identity<evaluate::Type<Unsigned, 1>>{});
95 case 2:
96 return visitor(llvm::type_identity<evaluate::Type<Unsigned, 2>>{});
97 case 4:
98 return visitor(llvm::type_identity<evaluate::Type<Unsigned, 4>>{});
99 case 8:
100 return visitor(llvm::type_identity<evaluate::Type<Unsigned, 8>>{});
101 case 16:
102 return visitor(llvm::type_identity<evaluate::Type<Unsigned, 16>>{});
103 }
104 break;
105 case common::TypeCategory::Real:
106 switch (type.kind()) {
107 case 2:
108 return visitor(llvm::type_identity<evaluate::Type<Real, 2>>{});
109 case 3:
110 return visitor(llvm::type_identity<evaluate::Type<Real, 3>>{});
111 case 4:
112 return visitor(llvm::type_identity<evaluate::Type<Real, 4>>{});
113 case 8:
114 return visitor(llvm::type_identity<evaluate::Type<Real, 8>>{});
115 case 10:
116 return visitor(llvm::type_identity<evaluate::Type<Real, 10>>{});
117 case 16:
118 return visitor(llvm::type_identity<evaluate::Type<Real, 16>>{});
119 }
120 break;
121 case common::TypeCategory::Complex:
122 switch (type.kind()) {
123 case 2:
124 return visitor(llvm::type_identity<evaluate::Type<Complex, 2>>{});
125 case 3:
126 return visitor(llvm::type_identity<evaluate::Type<Complex, 3>>{});
127 case 4:
128 return visitor(llvm::type_identity<evaluate::Type<Complex, 4>>{});
129 case 8:
130 return visitor(llvm::type_identity<evaluate::Type<Complex, 8>>{});
131 case 10:
132 return visitor(llvm::type_identity<evaluate::Type<Complex, 10>>{});
133 case 16:
134 return visitor(llvm::type_identity<evaluate::Type<Complex, 16>>{});
135 }
136 break;
137 case common::TypeCategory::Logical:
138 switch (type.kind()) {
139 case 1:
140 return visitor(llvm::type_identity<evaluate::Type<Logical, 1>>{});
141 case 2:
142 return visitor(llvm::type_identity<evaluate::Type<Logical, 2>>{});
143 case 4:
144 return visitor(llvm::type_identity<evaluate::Type<Logical, 4>>{});
145 case 8:
146 return visitor(llvm::type_identity<evaluate::Type<Logical, 8>>{});
147 }
148 break;
149 case common::TypeCategory::Character:
150 switch (type.kind()) {
151 case 1:
152 return visitor(llvm::type_identity<evaluate::Type<Character, 1>>{});
153 case 2:
154 return visitor(llvm::type_identity<evaluate::Type<Character, 2>>{});
155 case 4:
156 return visitor(llvm::type_identity<evaluate::Type<Character, 4>>{});
157 }
158 break;
159 case common::TypeCategory::Derived:
160 (void)Derived;
161 break;
162 }
163 llvm_unreachable("Unhandled type");
164 }
165
166 const evaluate::DynamicType &type;
167
168private:
169 // Shorter names.
170 static constexpr auto Character = common::TypeCategory::Character;
171 static constexpr auto Complex = common::TypeCategory::Complex;
172 static constexpr auto Derived = common::TypeCategory::Derived;
173 static constexpr auto Integer = common::TypeCategory::Integer;
174 static constexpr auto Logical = common::TypeCategory::Logical;
175 static constexpr auto Real = common::TypeCategory::Real;
176 static constexpr auto Unsigned = common::TypeCategory::Unsigned;
177};
178
// Produces an rvalue from an lvalue without disturbing it: the argument
// (const or not) is copied into a fresh, non-const object that the caller
// may then move from.
template <typename T, typename U = std::remove_const_t<T>>
U AsRvalue(T &t) {
  // Parentheses, not braces: `U copy{t}` would prefer an initializer_list
  // constructor whenever one is viable (e.g. std::vector<std::any>), which
  // yields a one-element container instead of a copy.
  U copy(t);
  // Return the local by name so NRVO applies; std::move would prevent it.
  return copy;
}

// Rvalues are already movable: pass them through as an rvalue reference.
template <typename T>
T &&AsRvalue(T &&t) {
  return std::move(t);
}
189
190struct ArgumentReplacer
191 : public evaluate::Traverse<ArgumentReplacer, bool, false> {
192 using Base = evaluate::Traverse<ArgumentReplacer, bool, false>;
193 using Result = bool;
194
195 Result Default() const { return false; }
196
197 ArgumentReplacer(evaluate::ActualArguments &&newArgs)
198 : Base(*this), args_(std::move(newArgs)) {}
199
200 using Base::operator();
201
202 template <typename T>
203 Result operator()(const evaluate::FunctionRef<T> &x) {
204 assert(!done_);
205 auto &mut = const_cast<evaluate::FunctionRef<T> &>(x);
206 mut.arguments() = args_;
207 done_ = true;
208 return true;
209 }
210
211 Result Combine(Result &&a, Result &&b) { return a || b; }
212
213private:
214 bool done_{false};
215 evaluate::ActualArguments &&args_;
216};
217} // namespace
218
219[[maybe_unused]] static void
220dumpAtomicAnalysis(const parser::OpenMPAtomicConstruct::Analysis &analysis) {
221 auto whatStr = [](int k) {
222 std::string txt = "?";
223 switch (k & parser::OpenMPAtomicConstruct::Analysis::Action) {
224 case parser::OpenMPAtomicConstruct::Analysis::None:
225 txt = "None";
226 break;
227 case parser::OpenMPAtomicConstruct::Analysis::Read:
228 txt = "Read";
229 break;
230 case parser::OpenMPAtomicConstruct::Analysis::Write:
231 txt = "Write";
232 break;
233 case parser::OpenMPAtomicConstruct::Analysis::Update:
234 txt = "Update";
235 break;
236 }
237 switch (k & parser::OpenMPAtomicConstruct::Analysis::Condition) {
238 case parser::OpenMPAtomicConstruct::Analysis::IfTrue:
239 txt += " | IfTrue";
240 break;
241 case parser::OpenMPAtomicConstruct::Analysis::IfFalse:
242 txt += " | IfFalse";
243 break;
244 }
245 return txt;
246 };
247
248 auto exprStr = [&](const parser::TypedExpr &expr) {
249 if (auto *maybe = expr.get()) {
250 if (maybe->v)
251 return maybe->v->AsFortran();
252 }
253 return "<null>"s;
254 };
255 auto assignStr = [&](const parser::AssignmentStmt::TypedAssignment &assign) {
256 if (auto *maybe = assign.get(); maybe && maybe->v) {
257 std::string str;
258 llvm::raw_string_ostream os(str);
259 maybe->v->AsFortran(os);
260 return str;
261 }
262 return "<null>"s;
263 };
264
265 const semantics::SomeExpr &atom = *analysis.atom.get()->v;
266
267 llvm::errs() << "Analysis {\n";
268 llvm::errs() << " atom: " << atom.AsFortran() << "\n";
269 llvm::errs() << " cond: " << exprStr(analysis.cond) << "\n";
270 llvm::errs() << " op0 {\n";
271 llvm::errs() << " what: " << whatStr(analysis.op0.what) << "\n";
272 llvm::errs() << " assign: " << assignStr(analysis.op0.assign) << "\n";
273 llvm::errs() << " }\n";
274 llvm::errs() << " op1 {\n";
275 llvm::errs() << " what: " << whatStr(analysis.op1.what) << "\n";
276 llvm::errs() << " assign: " << assignStr(analysis.op1.assign) << "\n";
277 llvm::errs() << " }\n";
278 llvm::errs() << "}\n";
279}
280
281static bool isPointerAssignment(const evaluate::Assignment &assign) {
282 return common::visit(
283 common::visitors{
284 [](const evaluate::Assignment::BoundsSpec &) { return true; },
285 [](const evaluate::Assignment::BoundsRemapping &) { return true; },
286 [](const auto &) { return false; },
287 },
288 assign.u);
289}
290
291static fir::FirOpBuilder::InsertPoint
292getInsertionPointBefore(mlir::Operation *op) {
293 return fir::FirOpBuilder::InsertPoint(op->getBlock(),
294 mlir::Block::iterator(op));
295}
296
297static fir::FirOpBuilder::InsertPoint
298getInsertionPointAfter(mlir::Operation *op) {
299 return fir::FirOpBuilder::InsertPoint(op->getBlock(),
300 ++mlir::Block::iterator(op));
301}
302
303static mlir::IntegerAttr getAtomicHint(lower::AbstractConverter &converter,
304 const omp::List<omp::Clause> &clauses) {
305 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
306 for (const omp::Clause &clause : clauses) {
307 if (clause.id != llvm::omp::Clause::OMPC_hint)
308 continue;
309 auto &hint = std::get<omp::clause::Hint>(clause.u);
310 auto maybeVal = evaluate::ToInt64(hint.v);
311 CHECK(maybeVal);
312 return builder.getI64IntegerAttr(*maybeVal);
313 }
314 return nullptr;
315}
316
317static mlir::omp::ClauseMemoryOrderKind
318getMemoryOrderKind(common::OmpMemoryOrderType kind) {
319 switch (kind) {
320 case common::OmpMemoryOrderType::Acq_Rel:
321 return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
322 case common::OmpMemoryOrderType::Acquire:
323 return mlir::omp::ClauseMemoryOrderKind::Acquire;
324 case common::OmpMemoryOrderType::Relaxed:
325 return mlir::omp::ClauseMemoryOrderKind::Relaxed;
326 case common::OmpMemoryOrderType::Release:
327 return mlir::omp::ClauseMemoryOrderKind::Release;
328 case common::OmpMemoryOrderType::Seq_Cst:
329 return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
330 }
331 llvm_unreachable("Unexpected kind");
332}
333
334static std::optional<mlir::omp::ClauseMemoryOrderKind>
335getMemoryOrderKind(llvm::omp::Clause clauseId) {
336 switch (clauseId) {
337 case llvm::omp::Clause::OMPC_acq_rel:
338 return mlir::omp::ClauseMemoryOrderKind::Acq_rel;
339 case llvm::omp::Clause::OMPC_acquire:
340 return mlir::omp::ClauseMemoryOrderKind::Acquire;
341 case llvm::omp::Clause::OMPC_relaxed:
342 return mlir::omp::ClauseMemoryOrderKind::Relaxed;
343 case llvm::omp::Clause::OMPC_release:
344 return mlir::omp::ClauseMemoryOrderKind::Release;
345 case llvm::omp::Clause::OMPC_seq_cst:
346 return mlir::omp::ClauseMemoryOrderKind::Seq_cst;
347 default:
348 return std::nullopt;
349 }
350}
351
352static std::optional<mlir::omp::ClauseMemoryOrderKind>
353getMemoryOrderFromRequires(const semantics::Scope &scope) {
354 // The REQUIRES construct is only allowed in the main program scope
355 // and module scope, but seems like we also accept it in a subprogram
356 // scope.
357 // For safety, traverse all enclosing scopes and check if their symbol
358 // contains REQUIRES.
359 for (const auto *sc{&scope}; sc->kind() != semantics::Scope::Kind::Global;
360 sc = &sc->parent()) {
361 const semantics::Symbol *sym = sc->symbol();
362 if (!sym)
363 continue;
364
365 const common::OmpMemoryOrderType *admo = common::visit(
366 [](auto &&s) {
367 using WithOmpDeclarative = semantics::WithOmpDeclarative;
368 if constexpr (std::is_convertible_v<decltype(s),
369 const WithOmpDeclarative &>) {
370 return s.ompAtomicDefaultMemOrder();
371 }
372 return static_cast<const common::OmpMemoryOrderType *>(nullptr);
373 },
374 sym->details());
375 if (admo)
376 return getMemoryOrderKind(*admo);
377 }
378
379 return std::nullopt;
380}
381
382static std::optional<mlir::omp::ClauseMemoryOrderKind>
383getDefaultAtomicMemOrder(semantics::SemanticsContext &semaCtx) {
384 unsigned version = semaCtx.langOptions().OpenMPVersion;
385 if (version > 50)
386 return mlir::omp::ClauseMemoryOrderKind::Relaxed;
387 return std::nullopt;
388}
389
390static std::optional<mlir::omp::ClauseMemoryOrderKind>
391getAtomicMemoryOrder(semantics::SemanticsContext &semaCtx,
392 const omp::List<omp::Clause> &clauses,
393 const semantics::Scope &scope) {
394 for (const omp::Clause &clause : clauses) {
395 if (auto maybeKind = getMemoryOrderKind(clause.id))
396 return *maybeKind;
397 }
398
399 if (auto maybeKind = getMemoryOrderFromRequires(scope))
400 return *maybeKind;
401
402 return getDefaultAtomicMemOrder(semaCtx);
403}
404
405static mlir::omp::ClauseMemoryOrderKindAttr
406makeMemOrderAttr(lower::AbstractConverter &converter,
407 std::optional<mlir::omp::ClauseMemoryOrderKind> maybeKind) {
408 if (maybeKind) {
409 return mlir::omp::ClauseMemoryOrderKindAttr::get(
410 converter.getFirOpBuilder().getContext(), *maybeKind);
411 }
412 return nullptr;
413}
414
415static bool replaceArgs(semantics::SomeExpr &expr,
416 evaluate::ActualArguments &&newArgs) {
417 return ArgumentReplacer(std::move(newArgs))(expr);
418}
419
420static semantics::SomeExpr makeCall(const evaluate::DynamicType &type,
421 const evaluate::ProcedureDesignator &proc,
422 const evaluate::ActualArguments &args) {
423 return WithType(type).visit([&](auto &&s) -> semantics::SomeExpr {
424 using Type = typename llvm::remove_cvref_t<decltype(s)>::type;
425 return evaluate::AsGenericExpr(
426 evaluate::FunctionRef<Type>(AsRvalue(proc), AsRvalue(args)));
427 });
428}
429
430static const evaluate::ProcedureDesignator &
431getProcedureDesignator(const semantics::SomeExpr &call) {
432 const evaluate::ProcedureDesignator *proc = GetProc{}(call);
433 assert(proc && "Call has no procedure designator");
434 return *proc;
435}
436
437static semantics::SomeExpr //
438genReducedMinMax(const semantics::SomeExpr &orig,
439 const semantics::SomeExpr *atomArg,
440 const std::vector<semantics::SomeExpr> &args) {
441 // Take a list of arguments to a min/max operation, e.g. [a0, a1, ...]
442 // One of the a_i's, say a_t, must be atomArg.
443 // Generate tmp = min/max(a0, a1, ... [except a_t]). Then generate
444 // call = min/max(a_t, tmp).
445 // Return "call".
446
447 // The min/max intrinsics have 2 mandatory arguments, the rest is optional.
448 // Make sure that the "tmp = min/max(...)" doesn't promote an optional
449 // argument to a non-optional position. This could happen if a_t is at
450 // position 0 or 1.
451 if (args.size() <= 2)
452 return orig;
453
454 evaluate::ActualArguments nonAtoms;
455
456 auto AsActual = [](const semantics::SomeExpr &x) {
457 semantics::SomeExpr copy = x;
458 return evaluate::ActualArgument(std::move(copy));
459 };
460 // Semantic checks guarantee that the "atom" shows exactly once in the
461 // argument list (with potential conversions around it).
462 // For the first two (non-optional) arguments, if "atom" is among them,
463 // replace it with another occurrence of the other non-optional argument.
464 if (atomArg == &args[0]) {
465 // (atom, x, y...) -> (x, x, y...)
466 nonAtoms.push_back(AsActual(args[1]));
467 nonAtoms.push_back(AsActual(args[1]));
468 } else if (atomArg == &args[1]) {
469 // (x, atom, y...) -> (x, x, y...)
470 nonAtoms.push_back(AsActual(args[0]));
471 nonAtoms.push_back(AsActual(args[0]));
472 } else {
473 // (x, y, z...) -> unchanged
474 nonAtoms.push_back(AsActual(args[0]));
475 nonAtoms.push_back(AsActual(args[1]));
476 }
477
478 // The rest of arguments are optional, so we can just skip "atom".
479 for (size_t i = 2, e = args.size(); i != e; ++i) {
480 if (atomArg != &args[i])
481 nonAtoms.push_back(AsActual(args[i]));
482 }
483
484 // The type of the intermediate min/max is the same as the type of its
485 // arguments, which may be different from the type of the original
486 // expression. The original expression may have additional coverts.
487 auto tmp =
488 makeCall(*atomArg->GetType(), getProcedureDesignator(orig), nonAtoms);
489 semantics::SomeExpr call = orig;
490 replaceArgs(call, {AsActual(*atomArg), AsActual(tmp)});
491 return call;
492}
493
494static mlir::Operation * //
495genAtomicRead(lower::AbstractConverter &converter,
496 semantics::SemanticsContext &semaCtx, mlir::Location loc,
497 lower::StatementContext &stmtCtx, mlir::Value atomAddr,
498 const semantics::SomeExpr &atom,
499 const evaluate::Assignment &assign, mlir::IntegerAttr hint,
500 std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
501 fir::FirOpBuilder::InsertPoint preAt,
502 fir::FirOpBuilder::InsertPoint atomicAt,
503 fir::FirOpBuilder::InsertPoint postAt) {
504 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
505 builder.restoreInsertionPoint(preAt);
506
507 // If the atomic clause is read then the memory-order clause must
508 // not be release.
509 if (memOrder) {
510 if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Release) {
511 // Reset it back to the default.
512 memOrder = getDefaultAtomicMemOrder(semaCtx);
513 } else if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) {
514 // The MLIR verifier doesn't like acq_rel either.
515 memOrder = mlir::omp::ClauseMemoryOrderKind::Acquire;
516 }
517 }
518
519 mlir::Value storeAddr =
520 fir::getBase(converter.genExprAddr(assign.lhs, stmtCtx, &loc));
521 mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
522 mlir::Type storeType = fir::unwrapRefType(storeAddr.getType());
523
524 mlir::Value toAddr = [&]() {
525 if (atomType == storeType)
526 return storeAddr;
527 return builder.createTemporary(loc, atomType, ".tmp.atomval");
528 }();
529
530 builder.restoreInsertionPoint(atomicAt);
531 mlir::Operation *op = builder.create<mlir::omp::AtomicReadOp>(
532 loc, atomAddr, toAddr, mlir::TypeAttr::get(atomType), hint,
533 makeMemOrderAttr(converter, memOrder));
534
535 if (atomType != storeType) {
536 lower::ExprToValueMap overrides;
537 // The READ operation could be a part of UPDATE CAPTURE, so make sure
538 // we don't emit extra code into the body of the atomic op.
539 builder.restoreInsertionPoint(postAt);
540 mlir::Value load = builder.create<fir::LoadOp>(loc, toAddr);
541 overrides.try_emplace(&atom, load);
542
543 converter.overrideExprValues(&overrides);
544 mlir::Value value =
545 fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
546 converter.resetExprOverrides();
547
548 builder.create<fir::StoreOp>(loc, value, storeAddr);
549 }
550 return op;
551}
552
553static mlir::Operation * //
554genAtomicWrite(lower::AbstractConverter &converter,
555 semantics::SemanticsContext &semaCtx, mlir::Location loc,
556 lower::StatementContext &stmtCtx, mlir::Value atomAddr,
557 const semantics::SomeExpr &atom,
558 const evaluate::Assignment &assign, mlir::IntegerAttr hint,
559 std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
560 fir::FirOpBuilder::InsertPoint preAt,
561 fir::FirOpBuilder::InsertPoint atomicAt,
562 fir::FirOpBuilder::InsertPoint postAt) {
563 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
564 builder.restoreInsertionPoint(preAt);
565
566 // If the atomic clause is write then the memory-order clause must
567 // not be acquire.
568 if (memOrder) {
569 if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acquire) {
570 // Reset it back to the default.
571 memOrder = getDefaultAtomicMemOrder(semaCtx);
572 } else if (*memOrder == mlir::omp::ClauseMemoryOrderKind::Acq_rel) {
573 // The MLIR verifier doesn't like acq_rel either.
574 memOrder = mlir::omp::ClauseMemoryOrderKind::Release;
575 }
576 }
577
578 mlir::Value value =
579 fir::getBase(converter.genExprValue(assign.rhs, stmtCtx, &loc));
580 mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
581 mlir::Value converted = builder.createConvert(loc, atomType, value);
582
583 builder.restoreInsertionPoint(atomicAt);
584 mlir::Operation *op = builder.create<mlir::omp::AtomicWriteOp>(
585 loc, atomAddr, converted, hint, makeMemOrderAttr(converter, memOrder));
586 return op;
587}
588
589static mlir::Operation *
590genAtomicUpdate(lower::AbstractConverter &converter,
591 semantics::SemanticsContext &semaCtx, mlir::Location loc,
592 lower::StatementContext &stmtCtx, mlir::Value atomAddr,
593 const semantics::SomeExpr &atom,
594 const evaluate::Assignment &assign, mlir::IntegerAttr hint,
595 std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
596 fir::FirOpBuilder::InsertPoint preAt,
597 fir::FirOpBuilder::InsertPoint atomicAt,
598 fir::FirOpBuilder::InsertPoint postAt) {
599 lower::ExprToValueMap overrides;
600 lower::StatementContext naCtx;
601 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
602 builder.restoreInsertionPoint(preAt);
603
604 mlir::Type atomType = fir::unwrapRefType(atomAddr.getType());
605
606 // This must exist by now.
607 semantics::SomeExpr rhs = assign.rhs;
608 semantics::SomeExpr input = *evaluate::GetConvertInput(rhs);
609 auto [opcode, args] = evaluate::GetTopLevelOperation(input);
610 assert(!args.empty() && "Update operation without arguments");
611
612 // Pass args as an argument to avoid capturing a structured binding.
613 const semantics::SomeExpr *atomArg = [&](auto &args) {
614 for (const semantics::SomeExpr &e : args) {
615 if (evaluate::IsSameOrConvertOf(e, atom))
616 return &e;
617 }
618 llvm_unreachable("Atomic variable not in argument list");
619 }(args);
620
621 if (opcode == evaluate::operation::Operator::Min ||
622 opcode == evaluate::operation::Operator::Max) {
623 // Min and max operations are expanded inline, so reduce them to
624 // operations with exactly two (non-optional) arguments.
625 rhs = genReducedMinMax(rhs, atomArg, args);
626 input = *evaluate::GetConvertInput(rhs);
627 std::tie(opcode, args) = evaluate::GetTopLevelOperation(input);
628 atomArg = nullptr; // No longer valid.
629 }
630 for (auto &arg : args) {
631 if (!evaluate::IsSameOrConvertOf(arg, atom)) {
632 mlir::Value val = fir::getBase(converter.genExprValue(arg, naCtx, &loc));
633 overrides.try_emplace(&arg, val);
634 }
635 }
636
637 builder.restoreInsertionPoint(atomicAt);
638 auto updateOp = builder.create<mlir::omp::AtomicUpdateOp>(
639 loc, atomAddr, hint, makeMemOrderAttr(converter, memOrder));
640
641 mlir::Region &region = updateOp->getRegion(0);
642 mlir::Block *block = builder.createBlock(&region, {}, {atomType}, {loc});
643 mlir::Value localAtom = fir::getBase(block->getArgument(0));
644 overrides.try_emplace(&atom, localAtom);
645
646 converter.overrideExprValues(&overrides);
647 mlir::Value updated =
648 fir::getBase(converter.genExprValue(rhs, stmtCtx, &loc));
649 mlir::Value converted = builder.createConvert(loc, atomType, updated);
650 builder.create<mlir::omp::YieldOp>(loc, converted);
651 converter.resetExprOverrides();
652
653 builder.restoreInsertionPoint(postAt); // For naCtx cleanups
654 return updateOp;
655}
656
657static mlir::Operation *
658genAtomicOperation(lower::AbstractConverter &converter,
659 semantics::SemanticsContext &semaCtx, mlir::Location loc,
660 lower::StatementContext &stmtCtx, int action,
661 mlir::Value atomAddr, const semantics::SomeExpr &atom,
662 const evaluate::Assignment &assign, mlir::IntegerAttr hint,
663 std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder,
664 fir::FirOpBuilder::InsertPoint preAt,
665 fir::FirOpBuilder::InsertPoint atomicAt,
666 fir::FirOpBuilder::InsertPoint postAt) {
667 if (isPointerAssignment(assign)) {
668 TODO(loc, "Code generation for pointer assignment is not implemented yet");
669 }
670
671 // This function and the functions called here do not preserve the
672 // builder's insertion point, or set it to anything specific.
673 switch (action) {
674 case parser::OpenMPAtomicConstruct::Analysis::Read:
675 return genAtomicRead(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
676 assign, hint, memOrder, preAt, atomicAt, postAt);
677 case parser::OpenMPAtomicConstruct::Analysis::Write:
678 return genAtomicWrite(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
679 assign, hint, memOrder, preAt, atomicAt, postAt);
680 case parser::OpenMPAtomicConstruct::Analysis::Update:
681 return genAtomicUpdate(converter, semaCtx, loc, stmtCtx, atomAddr, atom,
682 assign, hint, memOrder, preAt, atomicAt, postAt);
683 default:
684 return nullptr;
685 }
686}
687
688void Fortran::lower::omp::lowerAtomic(
689 AbstractConverter &converter, SymMap &symTable,
690 semantics::SemanticsContext &semaCtx, pft::Evaluation &eval,
691 const parser::OpenMPAtomicConstruct &construct) {
692 auto get = [](auto &&typedWrapper) -> decltype(&*typedWrapper.get()->v) {
693 if (auto *maybe = typedWrapper.get(); maybe && maybe->v) {
694 return &*maybe->v;
695 } else {
696 return nullptr;
697 }
698 };
699
700 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
701 auto &dirSpec = std::get<parser::OmpDirectiveSpecification>(construct.t);
702 omp::List<omp::Clause> clauses = makeClauses(dirSpec.Clauses(), semaCtx);
703 lower::StatementContext stmtCtx;
704
705 const parser::OpenMPAtomicConstruct::Analysis &analysis = construct.analysis;
706 if (DumpAtomicAnalysis)
707 dumpAtomicAnalysis(analysis);
708
709 const semantics::SomeExpr &atom = *get(analysis.atom);
710 mlir::Location loc = converter.genLocation(construct.source);
711 mlir::Value atomAddr =
712 fir::getBase(converter.genExprAddr(atom, stmtCtx, &loc));
713 mlir::IntegerAttr hint = getAtomicHint(converter, clauses);
714 std::optional<mlir::omp::ClauseMemoryOrderKind> memOrder =
715 getAtomicMemoryOrder(semaCtx, clauses,
716 semaCtx.FindScope(construct.source));
717
718 if (auto *cond = get(analysis.cond)) {
719 (void)cond;
720 TODO(loc, "OpenMP ATOMIC COMPARE");
721 } else {
722 int action0 = analysis.op0.what & analysis.Action;
723 int action1 = analysis.op1.what & analysis.Action;
724 mlir::Operation *captureOp = nullptr;
725 fir::FirOpBuilder::InsertPoint preAt = builder.saveInsertionPoint();
726 fir::FirOpBuilder::InsertPoint atomicAt, postAt;
727
728 if (construct.IsCapture()) {
729 // Capturing operation.
730 assert(action0 != analysis.None && action1 != analysis.None &&
731 "Expexcing two actions");
732 (void)action0;
733 (void)action1;
734 captureOp = builder.create<mlir::omp::AtomicCaptureOp>(
735 loc, hint, makeMemOrderAttr(converter, memOrder));
736 // Set the non-atomic insertion point to before the atomic.capture.
737 preAt = getInsertionPointBefore(captureOp);
738
739 mlir::Block *block = builder.createBlock(&captureOp->getRegion(0));
740 builder.setInsertionPointToEnd(block);
741 // Set the atomic insertion point to before the terminator inside
742 // atomic.capture.
743 mlir::Operation *term = builder.create<mlir::omp::TerminatorOp>(loc);
744 atomicAt = getInsertionPointBefore(term);
745 postAt = getInsertionPointAfter(captureOp);
746 hint = nullptr;
747 memOrder = std::nullopt;
748 } else {
749 // Non-capturing operation.
750 assert(action0 != analysis.None && action1 == analysis.None &&
751 "Expexcing single action");
752 assert(!(analysis.op0.what & analysis.Condition));
753 postAt = atomicAt = preAt;
754 }
755
756 // The builder's insertion point needs to be specifically set before
757 // each call to `genAtomicOperation`.
758 mlir::Operation *firstOp = genAtomicOperation(
759 converter, semaCtx, loc, stmtCtx, analysis.op0.what, atomAddr, atom,
760 *get(analysis.op0.assign), hint, memOrder, preAt, atomicAt, postAt);
761 assert(firstOp && "Should have created an atomic operation");
762 atomicAt = getInsertionPointAfter(firstOp);
763
764 mlir::Operation *secondOp = nullptr;
765 if (analysis.op1.what != analysis.None) {
766 secondOp = genAtomicOperation(
767 converter, semaCtx, loc, stmtCtx, analysis.op1.what, atomAddr, atom,
768 *get(analysis.op1.assign), hint, memOrder, preAt, atomicAt, postAt);
769 }
770
771 if (construct.IsCapture()) {
772 // If this is a capture operation, the first/second ops will be inside
773 // of it. Set the insertion point to past the capture op itself.
774 builder.restoreInsertionPoint(postAt);
775 } else {
776 if (secondOp) {
777 builder.setInsertionPointAfter(secondOp);
778 } else {
779 builder.setInsertionPointAfter(firstOp);
780 }
781 }
782 }
783}
784

source code of flang/lib/Lower/OpenMP/Atomic.cpp