1 | //===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
14 | #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
15 | |
16 | #include "Clauses.h" |
17 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
18 | #include "flang/Optimizer/Dialect/FIRType.h" |
19 | #include "flang/Parser/parse-tree.h" |
20 | #include "flang/Semantics/symbol.h" |
21 | #include "flang/Semantics/type.h" |
22 | #include "mlir/IR/Location.h" |
23 | #include "mlir/IR/Types.h" |
24 | |
// Forward declarations of types this header only names in declarations
// (parameters and return types), so their full definitions need not be
// included here.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp

namespace Fortran::lower {
class AbstractConverter;
} // namespace Fortran::lower
36 | |
37 | namespace Fortran { |
38 | namespace lower { |
39 | namespace omp { |
40 | |
41 | class ReductionProcessor { |
42 | public: |
43 | // TODO: Move this enumeration to the OpenMP dialect |
44 | enum ReductionIdentifier { |
45 | ID, |
46 | USER_DEF_OP, |
47 | ADD, |
48 | SUBTRACT, |
49 | MULTIPLY, |
50 | AND, |
51 | OR, |
52 | EQV, |
53 | NEQV, |
54 | MAX, |
55 | MIN, |
56 | IAND, |
57 | IOR, |
58 | IEOR |
59 | }; |
60 | |
61 | static ReductionIdentifier |
62 | getReductionType(const omp::clause::ProcedureDesignator &pd); |
63 | |
64 | static ReductionIdentifier |
65 | getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp); |
66 | |
67 | static bool |
68 | supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd); |
69 | |
70 | static const Fortran::semantics::SourceName |
71 | getRealName(const Fortran::semantics::Symbol *symbol); |
72 | |
73 | static const Fortran::semantics::SourceName |
74 | getRealName(const omp::clause::ProcedureDesignator &pd); |
75 | |
76 | static bool |
77 | doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars); |
78 | |
79 | static std::string getReductionName(llvm::StringRef name, |
80 | const fir::KindMapping &kindMap, |
81 | mlir::Type ty, bool isByRef); |
82 | |
83 | static std::string |
84 | getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
85 | const fir::KindMapping &kindMap, mlir::Type ty, |
86 | bool isByRef); |
87 | |
88 | /// This function returns the identity value of the operator \p |
89 | /// reductionOpName. For example: |
90 | /// 0 + x = x, |
91 | /// 1 * x = x |
92 | static int getOperationIdentity(ReductionIdentifier redId, |
93 | mlir::Location loc); |
94 | |
95 | static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, |
96 | ReductionIdentifier redId, |
97 | fir::FirOpBuilder &builder); |
98 | |
99 | template <typename FloatOp, typename IntegerOp> |
100 | static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, |
101 | mlir::Type type, mlir::Location loc, |
102 | mlir::Value op1, mlir::Value op2); |
103 | template <typename FloatOp, typename IntegerOp, typename ComplexOp> |
104 | static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, |
105 | mlir::Type type, mlir::Location loc, |
106 | mlir::Value op1, mlir::Value op2); |
107 | |
108 | static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, |
109 | mlir::Location loc, |
110 | ReductionIdentifier redId, |
111 | mlir::Type type, mlir::Value op1, |
112 | mlir::Value op2); |
113 | |
114 | /// Creates an OpenMP reduction declaration and inserts it into the provided |
115 | /// symbol table. The declaration has a constant initializer with the neutral |
116 | /// value `initValue`, and the reduction combiner carried over from `reduce`. |
117 | /// TODO: add atomic region. |
118 | static mlir::omp::DeclareReductionOp |
119 | createDeclareReduction(fir::FirOpBuilder &builder, |
120 | llvm::StringRef reductionOpName, |
121 | const ReductionIdentifier redId, mlir::Type type, |
122 | mlir::Location loc, bool isByRef); |
123 | |
124 | /// Creates a reduction declaration and associates it with an OpenMP block |
125 | /// directive. |
126 | static void addDeclareReduction( |
127 | mlir::Location currentLocation, |
128 | Fortran::lower::AbstractConverter &converter, |
129 | const omp::clause::Reduction &reduction, |
130 | llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
131 | llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
132 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
133 | *reductionSymbols = nullptr); |
134 | }; |
135 | |
136 | template <typename FloatOp, typename IntegerOp> |
137 | mlir::Value |
138 | ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, |
139 | mlir::Type type, mlir::Location loc, |
140 | mlir::Value op1, mlir::Value op2) { |
141 | type = fir::unwrapRefType(type); |
142 | assert(type.isIntOrIndexOrFloat() && |
143 | "only integer, float and complex types are currently supported" ); |
144 | if (type.isIntOrIndex()) |
145 | return builder.create<IntegerOp>(loc, op1, op2); |
146 | return builder.create<FloatOp>(loc, op1, op2); |
147 | } |
148 | |
149 | template <typename FloatOp, typename IntegerOp, typename ComplexOp> |
150 | mlir::Value |
151 | ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, |
152 | mlir::Type type, mlir::Location loc, |
153 | mlir::Value op1, mlir::Value op2) { |
154 | assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) && |
155 | "only integer, float and complex types are currently supported" ); |
156 | if (type.isIntOrIndex()) |
157 | return builder.create<IntegerOp>(loc, op1, op2); |
158 | if (fir::isa_real(type)) |
159 | return builder.create<FloatOp>(loc, op1, op2); |
160 | return builder.create<ComplexOp>(loc, op1, op2); |
161 | } |
162 | |
163 | } // namespace omp |
164 | } // namespace lower |
165 | } // namespace Fortran |
166 | |
167 | #endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
168 | |