//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
//
//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15
16#include "Clauses.h"
17#include "flang/Optimizer/Builder/FIRBuilder.h"
18#include "flang/Optimizer/Dialect/FIRType.h"
19#include "flang/Parser/parse-tree.h"
20#include "flang/Semantics/symbol.h"
21#include "flang/Semantics/type.h"
22#include "mlir/IR/Location.h"
23#include "mlir/IR/Types.h"
24
// Forward declarations: the declarations in this header only name these
// types (as return types or by reference), so their complete definitions are
// not required here.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp

namespace Fortran::lower {
class AbstractConverter;
} // namespace Fortran::lower
36
37namespace Fortran {
38namespace lower {
39namespace omp {
40
41class ReductionProcessor {
42public:
43 // TODO: Move this enumeration to the OpenMP dialect
44 enum ReductionIdentifier {
45 ID,
46 USER_DEF_OP,
47 ADD,
48 SUBTRACT,
49 MULTIPLY,
50 AND,
51 OR,
52 EQV,
53 NEQV,
54 MAX,
55 MIN,
56 IAND,
57 IOR,
58 IEOR
59 };
60
61 static ReductionIdentifier
62 getReductionType(const omp::clause::ProcedureDesignator &pd);
63
64 static ReductionIdentifier
65 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
66
67 static bool
68 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
69
70 static const Fortran::semantics::SourceName
71 getRealName(const Fortran::semantics::Symbol *symbol);
72
73 static const Fortran::semantics::SourceName
74 getRealName(const omp::clause::ProcedureDesignator &pd);
75
76 static bool
77 doReductionByRef(const llvm::SmallVectorImpl<mlir::Value> &reductionVars);
78
79 static std::string getReductionName(llvm::StringRef name,
80 const fir::KindMapping &kindMap,
81 mlir::Type ty, bool isByRef);
82
83 static std::string
84 getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
85 const fir::KindMapping &kindMap, mlir::Type ty,
86 bool isByRef);
87
88 /// This function returns the identity value of the operator \p
89 /// reductionOpName. For example:
90 /// 0 + x = x,
91 /// 1 * x = x
92 static int getOperationIdentity(ReductionIdentifier redId,
93 mlir::Location loc);
94
95 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
96 ReductionIdentifier redId,
97 fir::FirOpBuilder &builder);
98
99 template <typename FloatOp, typename IntegerOp>
100 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
101 mlir::Type type, mlir::Location loc,
102 mlir::Value op1, mlir::Value op2);
103 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
104 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
105 mlir::Type type, mlir::Location loc,
106 mlir::Value op1, mlir::Value op2);
107
108 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
109 mlir::Location loc,
110 ReductionIdentifier redId,
111 mlir::Type type, mlir::Value op1,
112 mlir::Value op2);
113
114 /// Creates an OpenMP reduction declaration and inserts it into the provided
115 /// symbol table. The declaration has a constant initializer with the neutral
116 /// value `initValue`, and the reduction combiner carried over from `reduce`.
117 /// TODO: add atomic region.
118 static mlir::omp::DeclareReductionOp
119 createDeclareReduction(fir::FirOpBuilder &builder,
120 llvm::StringRef reductionOpName,
121 const ReductionIdentifier redId, mlir::Type type,
122 mlir::Location loc, bool isByRef);
123
124 /// Creates a reduction declaration and associates it with an OpenMP block
125 /// directive.
126 static void addDeclareReduction(
127 mlir::Location currentLocation,
128 Fortran::lower::AbstractConverter &converter,
129 const omp::clause::Reduction &reduction,
130 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
131 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
132 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
133 *reductionSymbols = nullptr);
134};
135
136template <typename FloatOp, typename IntegerOp>
137mlir::Value
138ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
139 mlir::Type type, mlir::Location loc,
140 mlir::Value op1, mlir::Value op2) {
141 type = fir::unwrapRefType(type);
142 assert(type.isIntOrIndexOrFloat() &&
143 "only integer, float and complex types are currently supported");
144 if (type.isIntOrIndex())
145 return builder.create<IntegerOp>(loc, op1, op2);
146 return builder.create<FloatOp>(loc, op1, op2);
147}
148
149template <typename FloatOp, typename IntegerOp, typename ComplexOp>
150mlir::Value
151ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
152 mlir::Type type, mlir::Location loc,
153 mlir::Value op1, mlir::Value op2) {
154 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
155 "only integer, float and complex types are currently supported");
156 if (type.isIntOrIndex())
157 return builder.create<IntegerOp>(loc, op1, op2);
158 if (fir::isa_real(type))
159 return builder.create<FloatOp>(loc, op1, op2);
160 return builder.create<ComplexOp>(loc, op1, op2);
161}
162
163} // namespace omp
164} // namespace lower
165} // namespace Fortran
166
167#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
168

// source code of flang/lib/Lower/OpenMP/ReductionProcessor.h