1 | //===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
14 | #define FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
15 | |
16 | #include "Clauses.h" |
17 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
18 | #include "flang/Optimizer/Dialect/FIRType.h" |
19 | #include "flang/Parser/parse-tree.h" |
20 | #include "flang/Semantics/symbol.h" |
21 | #include "flang/Semantics/type.h" |
22 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
23 | #include "mlir/IR/Location.h" |
24 | #include "mlir/IR/Types.h" |
25 | |
// Forward declaration of the operation type returned by
// ReductionProcessor::createDeclareReduction below. (The complete type is
// also expected to be available through the OpenMPDialect.h include.)
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp
31 | |
// Forward declaration: this header only ever takes the converter by
// reference, so the complete type is not required here.
namespace Fortran::lower {
class AbstractConverter;
} // namespace Fortran::lower
37 | |
38 | namespace Fortran { |
39 | namespace lower { |
40 | namespace omp { |
41 | |
42 | class ReductionProcessor { |
43 | public: |
44 | // TODO: Move this enumeration to the OpenMP dialect |
45 | enum ReductionIdentifier { |
46 | ID, |
47 | USER_DEF_OP, |
48 | ADD, |
49 | SUBTRACT, |
50 | MULTIPLY, |
51 | AND, |
52 | OR, |
53 | EQV, |
54 | NEQV, |
55 | MAX, |
56 | MIN, |
57 | IAND, |
58 | IOR, |
59 | IEOR |
60 | }; |
61 | |
62 | static ReductionIdentifier |
63 | getReductionType(const omp::clause::ProcedureDesignator &pd); |
64 | |
65 | static ReductionIdentifier |
66 | getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp); |
67 | |
68 | static bool |
69 | supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd); |
70 | |
71 | static const semantics::SourceName |
72 | getRealName(const semantics::Symbol *symbol); |
73 | |
74 | static const semantics::SourceName |
75 | getRealName(const omp::clause::ProcedureDesignator &pd); |
76 | |
77 | static std::string getReductionName(llvm::StringRef name, |
78 | const fir::KindMapping &kindMap, |
79 | mlir::Type ty, bool isByRef); |
80 | |
81 | static std::string |
82 | getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
83 | const fir::KindMapping &kindMap, mlir::Type ty, |
84 | bool isByRef); |
85 | |
86 | /// This function returns the identity value of the operator \p |
87 | /// reductionOpName. For example: |
88 | /// 0 + x = x, |
89 | /// 1 * x = x |
90 | static int getOperationIdentity(ReductionIdentifier redId, |
91 | mlir::Location loc); |
92 | |
93 | static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type, |
94 | ReductionIdentifier redId, |
95 | fir::FirOpBuilder &builder); |
96 | |
97 | template <typename FloatOp, typename IntegerOp> |
98 | static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, |
99 | mlir::Type type, mlir::Location loc, |
100 | mlir::Value op1, mlir::Value op2); |
101 | template <typename FloatOp, typename IntegerOp, typename ComplexOp> |
102 | static mlir::Value getReductionOperation(fir::FirOpBuilder &builder, |
103 | mlir::Type type, mlir::Location loc, |
104 | mlir::Value op1, mlir::Value op2); |
105 | |
106 | static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder, |
107 | mlir::Location loc, |
108 | ReductionIdentifier redId, |
109 | mlir::Type type, mlir::Value op1, |
110 | mlir::Value op2); |
111 | |
112 | /// Creates an OpenMP reduction declaration and inserts it into the provided |
113 | /// symbol table. The declaration has a constant initializer with the neutral |
114 | /// value `initValue`, and the reduction combiner carried over from `reduce`. |
115 | /// TODO: add atomic region. |
116 | static mlir::omp::DeclareReductionOp |
117 | createDeclareReduction(AbstractConverter &builder, |
118 | llvm::StringRef reductionOpName, |
119 | const ReductionIdentifier redId, mlir::Type type, |
120 | mlir::Location loc, bool isByRef); |
121 | |
122 | /// Creates a reduction declaration and associates it with an OpenMP block |
123 | /// directive. |
124 | template <class T> |
125 | static void processReductionArguments( |
126 | mlir::Location currentLocation, lower::AbstractConverter &converter, |
127 | const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
128 | llvm::SmallVectorImpl<bool> &reduceVarByRef, |
129 | llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
130 | llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols, |
131 | mlir::omp::ReductionModifierAttr *reductionMod = nullptr); |
132 | }; |
133 | |
134 | template <typename FloatOp, typename IntegerOp> |
135 | mlir::Value |
136 | ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, |
137 | mlir::Type type, mlir::Location loc, |
138 | mlir::Value op1, mlir::Value op2) { |
139 | type = fir::unwrapRefType(type); |
140 | assert(type.isIntOrIndexOrFloat() && |
141 | "only integer, float and complex types are currently supported" ); |
142 | if (type.isIntOrIndex()) |
143 | return builder.create<IntegerOp>(loc, op1, op2); |
144 | return builder.create<FloatOp>(loc, op1, op2); |
145 | } |
146 | |
147 | template <typename FloatOp, typename IntegerOp, typename ComplexOp> |
148 | mlir::Value |
149 | ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder, |
150 | mlir::Type type, mlir::Location loc, |
151 | mlir::Value op1, mlir::Value op2) { |
152 | assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) && |
153 | "only integer, float and complex types are currently supported" ); |
154 | if (type.isIntOrIndex()) |
155 | return builder.create<IntegerOp>(loc, op1, op2); |
156 | if (fir::isa_real(type)) |
157 | return builder.create<FloatOp>(loc, op1, op2); |
158 | return builder.create<ComplexOp>(loc, op1, op2); |
159 | } |
160 | |
161 | } // namespace omp |
162 | } // namespace lower |
163 | } // namespace Fortran |
164 | |
165 | #endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H |
166 | |