Warning: This file is not a C or C++ file. It does not have highlighting.

1//===-- IterationSpace.h ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_ITERATIONSPACE_H
14#define FORTRAN_LOWER_ITERATIONSPACE_H
15
#include "flang/Evaluate/tools.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
#include <optional>
#include <utility>
21
22namespace llvm {
23class raw_ostream;
24}
25
26namespace Fortran {
27namespace evaluate {
28struct SomeType;
29template <typename>
30class Expr;
31} // namespace evaluate
32
33namespace lower {
34
35using FrontEndExpr = const evaluate::Expr<evaluate::SomeType> *;
36using FrontEndSymbol = const semantics::Symbol *;
37
38class AbstractConverter;
39
40} // namespace lower
41} // namespace Fortran
42
43namespace Fortran::lower {
44
45/// Abstraction of the iteration space for building the elemental compute loop
46/// of an array(-like) statement.
47class IterationSpace {
48public:
49 IterationSpace() = default;
50
51 template <typename A>
52 explicit IterationSpace(mlir::Value inArg, mlir::Value outRes,
53 llvm::iterator_range<A> range)
54 : inArg{inArg}, outRes{outRes}, indices{range.begin(), range.end()} {}
55
56 explicit IterationSpace(const IterationSpace &from,
57 llvm::ArrayRef<mlir::Value> idxs)
58 : inArg(from.inArg), outRes(from.outRes), element(from.element),
59 indices(idxs.begin(), idxs.end()) {}
60
61 /// Create a copy of the \p from IterationSpace and prepend the \p prefix
62 /// values and append the \p suffix values, respectively.
63 explicit IterationSpace(const IterationSpace &from,
64 llvm::ArrayRef<mlir::Value> prefix,
65 llvm::ArrayRef<mlir::Value> suffix)
66 : inArg(from.inArg), outRes(from.outRes), element(from.element) {
67 indices.assign(prefix.begin(), prefix.end());
68 indices.append(from.indices.begin(), from.indices.end());
69 indices.append(suffix.begin(), suffix.end());
70 }
71
72 bool empty() const { return indices.empty(); }
73
74 /// This is the output value as it appears as an argument in the innermost
75 /// loop in the nest. The output value is threaded through the loop (and
76 /// conditionals) to maintain proper SSA form.
77 mlir::Value innerArgument() const { return inArg; }
78
79 /// This is the output value as it appears as an output value from the
80 /// outermost loop in the loop nest. The output value is threaded through the
81 /// loop (and conditionals) to maintain proper SSA form.
82 mlir::Value outerResult() const { return outRes; }
83
84 /// Returns a vector for the iteration space. This vector is used to access
85 /// elements of arrays in the compute loop.
86 llvm::SmallVector<mlir::Value> iterVec() const { return indices; }
87
88 mlir::Value iterValue(std::size_t i) const {
89 assert(i < indices.size());
90 return indices[i];
91 }
92
93 /// Set (rewrite) the Value at a given index.
94 void setIndexValue(std::size_t i, mlir::Value v) {
95 assert(i < indices.size());
96 indices[i] = v;
97 }
98
99 void setIndexValues(llvm::ArrayRef<mlir::Value> vals) {
100 indices.assign(vals.begin(), vals.end());
101 }
102
103 void insertIndexValue(std::size_t i, mlir::Value av) {
104 assert(i <= indices.size());
105 indices.insert(indices.begin() + i, av);
106 }
107
108 /// Set the `element` value. This is the SSA value that corresponds to an
109 /// element of the resultant array value.
110 void setElement(fir::ExtendedValue &&ele) {
111 assert(!fir::getBase(element) && "result element already set");
112 element = ele;
113 }
114
115 /// Get the value that will be merged into the resultant array. This is the
116 /// computed value that will be stored to the lhs of the assignment.
117 mlir::Value getElement() const {
118 assert(fir::getBase(element) && "element must be set");
119 return fir::getBase(element);
120 }
121
122 /// Get the element as an extended value.
123 fir::ExtendedValue elementExv() const { return element; }
124
125 void clearIndices() { indices.clear(); }
126
127private:
128 mlir::Value inArg;
129 mlir::Value outRes;
130 fir::ExtendedValue element;
131 llvm::SmallVector<mlir::Value> indices;
132};
133
134using GenerateElementalArrayFunc =
135 std::function<fir::ExtendedValue(const IterationSpace &)>;
136
137template <typename A>
138class StackableConstructExpr {
139public:
140 bool empty() const { return stack.empty(); }
141
142 void growStack() { stack.push_back(A{}); }
143
144 /// Bind a front-end expression to a closure.
145 void bind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
146 vmap.insert({e, std::move(fun)});
147 }
148
149 /// Replace the binding of front-end expression `e` with a new closure.
150 void rebind(FrontEndExpr e, GenerateElementalArrayFunc &&fun) {
151 vmap.erase(e);
152 bind(e, std::move(fun));
153 }
154
155 /// Get the closure bound to the front-end expression, `e`.
156 GenerateElementalArrayFunc getBoundClosure(FrontEndExpr e) const {
157 if (!vmap.count(e))
158 llvm::report_fatal_error(
159 "evaluate::Expr is not in the map of lowered mask expressions");
160 return vmap.lookup(e);
161 }
162
163 /// Has the front-end expression, `e`, been lowered and bound?
164 bool isLowered(FrontEndExpr e) const { return vmap.count(e); }
165
166 StatementContext &stmtContext() { return stmtCtx; }
167
168protected:
169 void shrinkStack() {
170 assert(!empty());
171 stack.pop_back();
172 if (empty()) {
173 stmtCtx.finalizeAndReset();
174 vmap.clear();
175 }
176 }
177
178 // The stack for the construct information.
179 llvm::SmallVector<A> stack;
180
181 // Map each mask expression back to the temporary holding the initial
182 // evaluation results.
183 llvm::DenseMap<FrontEndExpr, GenerateElementalArrayFunc> vmap;
184
185 // Inflate the statement context for the entire construct. We have to cache
186 // the mask expression results, which are always evaluated first, across the
187 // entire construct.
188 StatementContext stmtCtx;
189};
190
191class ImplicitIterSpace;
192llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ImplicitIterSpace &);
193
194/// All array expressions have an implicit iteration space, which is isomorphic
195/// to the shape of the base array that facilitates the expression having a
196/// non-zero rank. This implied iteration space may be conditionalized
197/// (disjunctively) with an if-elseif-else like structure, specifically
198/// Fortran's WHERE construct.
199///
200/// This class is used in the bridge to collect the expressions from the
201/// front end (the WHERE construct mask expressions), forward them for lowering
202/// as array expressions in an "evaluate once" (copy-in, copy-out) semantics.
203/// See 10.2.3.2p3, 10.2.3.2p13, etc.
204class ImplicitIterSpace
205 : public StackableConstructExpr<llvm::SmallVector<FrontEndExpr>> {
206public:
207 using Base = StackableConstructExpr<llvm::SmallVector<FrontEndExpr>>;
208 using FrontEndMaskExpr = FrontEndExpr;
209
210 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
211 const ImplicitIterSpace &);
212
213 LLVM_DUMP_METHOD void dump() const;
214
215 void append(FrontEndMaskExpr e) {
216 assert(!empty());
217 getMasks().back().push_back(e);
218 }
219
220 llvm::SmallVector<FrontEndMaskExpr> getExprs() const {
221 llvm::SmallVector<FrontEndMaskExpr> maskList = getMasks()[0];
222 for (size_t i = 1, d = getMasks().size(); i < d; ++i)
223 maskList.append(getMasks()[i].begin(), getMasks()[i].end());
224 return maskList;
225 }
226
227 /// Add a variable binding, `var`, along with its shape for the mask
228 /// expression `exp`.
229 void addMaskVariable(FrontEndExpr exp, mlir::Value var, mlir::Value shape,
230 mlir::Value header) {
231 maskVarMap.try_emplace(exp, std::make_tuple(var, shape, header));
232 }
233
234 /// Lookup the variable corresponding to the temporary buffer that contains
235 /// the mask array expression results.
236 mlir::Value lookupMaskVariable(FrontEndExpr exp) {
237 return std::get<0>(maskVarMap.lookup(exp));
238 }
239
240 /// Lookup the variable containing the shape vector for the mask array
241 /// expression results.
242 mlir::Value lookupMaskShapeBuffer(FrontEndExpr exp) {
243 return std::get<1>(maskVarMap.lookup(exp));
244 }
245
246 mlir::Value lookupMaskHeader(FrontEndExpr exp) {
247 return std::get<2>(maskVarMap.lookup(exp));
248 }
249
250 // Stack of WHERE constructs, each building a list of mask expressions.
251 llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &getMasks() {
252 return stack;
253 }
254 const llvm::SmallVector<llvm::SmallVector<FrontEndMaskExpr>> &
255 getMasks() const {
256 return stack;
257 }
258
259 // Cleanup at the end of a WHERE statement or construct.
260 void shrinkStack() {
261 Base::shrinkStack();
262 if (stack.empty())
263 maskVarMap.clear();
264 }
265
266private:
267 llvm::DenseMap<FrontEndExpr,
268 std::tuple<mlir::Value, mlir::Value, mlir::Value>>
269 maskVarMap;
270};
271
272class ExplicitIterSpace;
273llvm::raw_ostream &operator<<(llvm::raw_ostream &, const ExplicitIterSpace &);
274
275/// Create all the array_load ops for the explicit iteration space context. The
276/// nest of FORALLs must have been analyzed a priori.
277void createArrayLoads(AbstractConverter &converter, ExplicitIterSpace &esp,
278 SymMap &symMap);
279
280/// Create the array_merge_store ops after the explicit iteration space context
281/// is conmpleted.
282void createArrayMergeStores(AbstractConverter &converter,
283 ExplicitIterSpace &esp);
284using ExplicitSpaceArrayBases =
285 std::variant<FrontEndSymbol, const evaluate::Component *,
286 const evaluate::ArrayRef *>;
287
288unsigned getHashValue(const ExplicitSpaceArrayBases &x);
289bool isEqual(const ExplicitSpaceArrayBases &x,
290 const ExplicitSpaceArrayBases &y);
291
292} // namespace Fortran::lower
293
294namespace llvm {
295template <>
296struct DenseMapInfo<Fortran::lower::ExplicitSpaceArrayBases> {
297 static inline Fortran::lower::ExplicitSpaceArrayBases getEmptyKey() {
298 return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0);
299 }
300 static inline Fortran::lower::ExplicitSpaceArrayBases getTombstoneKey() {
301 return reinterpret_cast<Fortran::lower::FrontEndSymbol>(~0 - 1);
302 }
303 static unsigned
304 getHashValue(const Fortran::lower::ExplicitSpaceArrayBases &v) {
305 return Fortran::lower::getHashValue(v);
306 }
307 static bool isEqual(const Fortran::lower::ExplicitSpaceArrayBases &lhs,
308 const Fortran::lower::ExplicitSpaceArrayBases &rhs) {
309 return Fortran::lower::isEqual(lhs, rhs);
310 }
311};
312} // namespace llvm
313
314namespace Fortran::lower {
315/// Fortran also allows arrays to be evaluated under constructs which allow the
316/// user to explicitly specify the iteration space using concurrent-control
317/// expressions. These constructs allow the user to define both an iteration
318/// space and explicit access vectors on arrays. These need not be isomorphic.
319/// The explicit iteration spaces may be conditionalized (conjunctively) with an
320/// "and" structure and may be found in FORALL (and DO CONCURRENT) constructs.
321///
322/// This class is used in the bridge to collect a stack of lists of
323/// concurrent-control expressions to be used to generate the iteration space
324/// and associated masks (if any) for a set of nested FORALL constructs around
325/// assignment and WHERE constructs.
326class ExplicitIterSpace {
327public:
328 using IterSpaceDim =
329 std::tuple<FrontEndSymbol, FrontEndExpr, FrontEndExpr, FrontEndExpr>;
330 using ConcurrentSpec =
331 std::pair<llvm::SmallVector<IterSpaceDim>, FrontEndExpr>;
332 using ArrayBases = ExplicitSpaceArrayBases;
333
334 friend void createArrayLoads(AbstractConverter &converter,
335 ExplicitIterSpace &esp, SymMap &symMap);
336 friend void createArrayMergeStores(AbstractConverter &converter,
337 ExplicitIterSpace &esp);
338
339 /// Is a FORALL context presently active?
340 /// If we are lowering constructs/statements nested within a FORALL, then a
341 /// FORALL context is active.
342 bool isActive() const { return forallContextOpen != 0; }
343
344 /// Get the statement context.
345 StatementContext &stmtContext() { return stmtCtx; }
346
347 //===--------------------------------------------------------------------===//
348 // Analysis support
349 //===--------------------------------------------------------------------===//
350
351 /// Open a new construct. The analysis phase starts here.
352 void pushLevel();
353
354 /// Close the construct.
355 void popLevel();
356
357 /// Add new concurrent header control variable symbol.
358 void addSymbol(FrontEndSymbol sym);
359
360 /// Collect array bases from the expression, `x`.
361 void exprBase(FrontEndExpr x, bool lhs);
362
363 /// Called at the end of a assignment statement.
364 void endAssign();
365
366 /// Return all the active control variables on the stack.
367 llvm::SmallVector<FrontEndSymbol> collectAllSymbols();
368
369 //===--------------------------------------------------------------------===//
370 // Code gen support
371 //===--------------------------------------------------------------------===//
372
373 /// Enter a FORALL context.
374 void enter() { forallContextOpen++; }
375
376 /// Leave a FORALL context.
377 void leave();
378
379 void pushLoopNest(std::function<void()> lambda) {
380 ccLoopNest.push_back(lambda);
381 }
382
383 /// Get the inner arguments that correspond to the output arrays.
384 mlir::ValueRange getInnerArgs() const { return innerArgs; }
385
386 /// Set the inner arguments for the next loop level.
387 void setInnerArgs(llvm::ArrayRef<mlir::BlockArgument> args) {
388 innerArgs.clear();
389 for (auto &arg : args)
390 innerArgs.push_back(arg);
391 }
392
393 /// Reset the outermost `array_load` arguments to the loop nest.
394 void resetInnerArgs() { innerArgs = initialArgs; }
395
396 /// Capture the current outermost loop.
397 void setOuterLoop(fir::DoLoopOp loop) {
398 clearLoops();
399 outerLoop = loop;
400 }
401
402 /// Sets the inner loop argument at position \p offset to \p val.
403 void setInnerArg(size_t offset, mlir::Value val) {
404 assert(offset < innerArgs.size());
405 innerArgs[offset] = val;
406 }
407
408 /// Get the types of the output arrays.
409 llvm::SmallVector<mlir::Type> innerArgTypes() const {
410 llvm::SmallVector<mlir::Type> result;
411 for (auto &arg : innerArgs)
412 result.push_back(arg.getType());
413 return result;
414 }
415
416 /// Create a binding between an Ev::Expr node pointer and a fir::array_load
417 /// op. This bindings will be used when generating the IR.
418 void bindLoad(ArrayBases base, fir::ArrayLoadOp load) {
419 loadBindings.try_emplace(std::move(base), load);
420 }
421
422 fir::ArrayLoadOp findBinding(const ArrayBases &base) {
423 return loadBindings.lookup(base);
424 }
425
426 /// `load` must be a LHS array_load. Returns `std::nullopt` on error.
427 std::optional<size_t> findArgPosition(fir::ArrayLoadOp load);
428
429 bool isLHS(fir::ArrayLoadOp load) {
430 return findArgPosition(load).has_value();
431 }
432
433 /// `load` must be a LHS array_load. Determine the threaded inner argument
434 /// corresponding to this load.
435 mlir::Value findArgumentOfLoad(fir::ArrayLoadOp load) {
436 if (auto opt = findArgPosition(load))
437 return innerArgs[*opt];
438 llvm_unreachable("array load argument not found");
439 }
440
441 size_t argPosition(mlir::Value arg) {
442 for (auto i : llvm::enumerate(innerArgs))
443 if (arg == i.value())
444 return i.index();
445 llvm_unreachable("inner argument value was not found");
446 }
447
448 std::optional<fir::ArrayLoadOp> getLhsLoad(size_t i) {
449 assert(i < lhsBases.size());
450 if (lhsBases[counter])
451 return findBinding(*lhsBases[counter]);
452 return std::nullopt;
453 }
454
455 /// Return the outermost loop in this FORALL nest.
456 fir::DoLoopOp getOuterLoop() {
457 assert(outerLoop.has_value());
458 return *outerLoop;
459 }
460
461 /// Return the statement context for the entire, outermost FORALL construct.
462 StatementContext &outermostContext() { return outerContext; }
463
464 /// Generate the explicit loop nest.
465 void genLoopNest() {
466 for (auto &lambda : ccLoopNest)
467 lambda();
468 }
469
470 /// Clear the array_load bindings.
471 void resetBindings() { loadBindings.clear(); }
472
473 /// Get the current counter value.
474 std::size_t getCounter() const { return counter; }
475
476 /// Increment the counter value to the next assignment statement.
477 void incrementCounter() { counter++; }
478
479 bool isOutermostForall() const {
480 assert(forallContextOpen);
481 return forallContextOpen == 1;
482 }
483
484 void attachLoopCleanup(std::function<void(fir::FirOpBuilder &builder)> fn) {
485 if (!loopCleanup) {
486 loopCleanup = fn;
487 return;
488 }
489 std::function<void(fir::FirOpBuilder &)> oldFn = *loopCleanup;
490 loopCleanup = [=](fir::FirOpBuilder &builder) {
491 oldFn(builder);
492 fn(builder);
493 };
494 }
495
496 // LLVM standard dump method.
497 LLVM_DUMP_METHOD void dump() const;
498
499 // Pretty-print.
500 friend llvm::raw_ostream &operator<<(llvm::raw_ostream &,
501 const ExplicitIterSpace &);
502
503 /// Finalize the current body statement context.
504 void finalizeContext() { stmtCtx.finalizeAndReset(); }
505
506 void appendLoops(const llvm::SmallVector<fir::DoLoopOp> &loops) {
507 loopStack.push_back(loops);
508 }
509
510 void clearLoops() { loopStack.clear(); }
511
512 llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> getLoopStack() const {
513 return loopStack;
514 }
515
516private:
517 /// Cleanup the analysis results.
518 void conditionalCleanup();
519
520 StatementContext outerContext;
521
522 // A stack of lists of front-end symbols.
523 llvm::SmallVector<llvm::SmallVector<FrontEndSymbol>> symbolStack;
524 llvm::SmallVector<std::optional<ArrayBases>> lhsBases;
525 llvm::SmallVector<llvm::SmallVector<ArrayBases>> rhsBases;
526 llvm::DenseMap<ArrayBases, fir::ArrayLoadOp> loadBindings;
527
528 // Stack of lambdas to create the loop nest.
529 llvm::SmallVector<std::function<void()>> ccLoopNest;
530
531 // Assignment statement context (inside the loop nest).
532 StatementContext stmtCtx;
533 llvm::SmallVector<mlir::Value> innerArgs;
534 llvm::SmallVector<mlir::Value> initialArgs;
535 std::optional<fir::DoLoopOp> outerLoop;
536 llvm::SmallVector<llvm::SmallVector<fir::DoLoopOp>> loopStack;
537 std::optional<std::function<void(fir::FirOpBuilder &)>> loopCleanup;
538 std::size_t forallContextOpen = 0;
539 std::size_t counter = 0;
540};
541
542/// Is there a Symbol in common between the concurrent header set and the set
543/// of symbols in the expression?
544template <typename A>
545bool symbolSetsIntersect(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
546 const A &exprSyms) {
547 for (const auto &sym : exprSyms)
548 if (llvm::is_contained(ctrlSet, &sym.get()))
549 return true;
550 return false;
551}
552
553/// Determine if the subscript expression symbols from an Ev::ArrayRef
554/// intersects with the set of concurrent control symbols, `ctrlSet`.
555template <typename A>
556bool symbolsIntersectSubscripts(llvm::ArrayRef<FrontEndSymbol> ctrlSet,
557 const A &subscripts) {
558 for (auto &sub : subscripts) {
559 if (const auto *expr =
560 std::get_if<evaluate::IndirectSubscriptIntegerExpr>(&sub.u))
561 if (symbolSetsIntersect(ctrlSet, evaluate::CollectSymbols(expr->value())))
562 return true;
563 }
564 return false;
565}
566
567} // namespace Fortran::lower
568
569#endif // FORTRAN_LOWER_ITERATIONSPACE_H
570

Warning: This file is not a C or C++ file. It does not have highlighting.

source code of flang/include/flang/Lower/IterationSpace.h