1//===-- Lower/OpenMP/ReductionProcessor.h -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
14#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
15
16#include "Clauses.h"
17#include "flang/Optimizer/Builder/FIRBuilder.h"
18#include "flang/Optimizer/Dialect/FIRType.h"
19#include "flang/Parser/parse-tree.h"
20#include "flang/Semantics/symbol.h"
21#include "flang/Semantics/type.h"
22#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
23#include "mlir/IR/Location.h"
24#include "mlir/IR/Types.h"
25
// Forward declarations of types named in the ReductionProcessor interface.
// NOTE(review): mlir::omp::DeclareReductionOp is presumably already declared
// by the OpenMPDialect.h include above, making its forward declaration
// redundant; kept for explicitness.
namespace mlir::omp {
class DeclareReductionOp;
} // namespace mlir::omp

namespace Fortran::lower {
class AbstractConverter;
} // namespace Fortran::lower
37
38namespace Fortran {
39namespace lower {
40namespace omp {
41
42class ReductionProcessor {
43public:
44 // TODO: Move this enumeration to the OpenMP dialect
45 enum ReductionIdentifier {
46 ID,
47 USER_DEF_OP,
48 ADD,
49 SUBTRACT,
50 MULTIPLY,
51 AND,
52 OR,
53 EQV,
54 NEQV,
55 MAX,
56 MIN,
57 IAND,
58 IOR,
59 IEOR
60 };
61
62 static ReductionIdentifier
63 getReductionType(const omp::clause::ProcedureDesignator &pd);
64
65 static ReductionIdentifier
66 getReductionType(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp);
67
68 static bool
69 supportedIntrinsicProcReduction(const omp::clause::ProcedureDesignator &pd);
70
71 static const semantics::SourceName
72 getRealName(const semantics::Symbol *symbol);
73
74 static const semantics::SourceName
75 getRealName(const omp::clause::ProcedureDesignator &pd);
76
77 static std::string getReductionName(llvm::StringRef name,
78 const fir::KindMapping &kindMap,
79 mlir::Type ty, bool isByRef);
80
81 static std::string
82 getReductionName(omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
83 const fir::KindMapping &kindMap, mlir::Type ty,
84 bool isByRef);
85
86 /// This function returns the identity value of the operator \p
87 /// reductionOpName. For example:
88 /// 0 + x = x,
89 /// 1 * x = x
90 static int getOperationIdentity(ReductionIdentifier redId,
91 mlir::Location loc);
92
93 static mlir::Value getReductionInitValue(mlir::Location loc, mlir::Type type,
94 ReductionIdentifier redId,
95 fir::FirOpBuilder &builder);
96
97 template <typename FloatOp, typename IntegerOp>
98 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
99 mlir::Type type, mlir::Location loc,
100 mlir::Value op1, mlir::Value op2);
101 template <typename FloatOp, typename IntegerOp, typename ComplexOp>
102 static mlir::Value getReductionOperation(fir::FirOpBuilder &builder,
103 mlir::Type type, mlir::Location loc,
104 mlir::Value op1, mlir::Value op2);
105
106 static mlir::Value createScalarCombiner(fir::FirOpBuilder &builder,
107 mlir::Location loc,
108 ReductionIdentifier redId,
109 mlir::Type type, mlir::Value op1,
110 mlir::Value op2);
111
112 /// Creates an OpenMP reduction declaration and inserts it into the provided
113 /// symbol table. The declaration has a constant initializer with the neutral
114 /// value `initValue`, and the reduction combiner carried over from `reduce`.
115 /// TODO: add atomic region.
116 static mlir::omp::DeclareReductionOp
117 createDeclareReduction(AbstractConverter &builder,
118 llvm::StringRef reductionOpName,
119 const ReductionIdentifier redId, mlir::Type type,
120 mlir::Location loc, bool isByRef);
121
122 /// Creates a reduction declaration and associates it with an OpenMP block
123 /// directive.
124 template <class T>
125 static void processReductionArguments(
126 mlir::Location currentLocation, lower::AbstractConverter &converter,
127 const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
128 llvm::SmallVectorImpl<bool> &reduceVarByRef,
129 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
130 llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
131 mlir::omp::ReductionModifierAttr *reductionMod = nullptr);
132};
133
134template <typename FloatOp, typename IntegerOp>
135mlir::Value
136ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
137 mlir::Type type, mlir::Location loc,
138 mlir::Value op1, mlir::Value op2) {
139 type = fir::unwrapRefType(type);
140 assert(type.isIntOrIndexOrFloat() &&
141 "only integer, float and complex types are currently supported");
142 if (type.isIntOrIndex())
143 return builder.create<IntegerOp>(loc, op1, op2);
144 return builder.create<FloatOp>(loc, op1, op2);
145}
146
147template <typename FloatOp, typename IntegerOp, typename ComplexOp>
148mlir::Value
149ReductionProcessor::getReductionOperation(fir::FirOpBuilder &builder,
150 mlir::Type type, mlir::Location loc,
151 mlir::Value op1, mlir::Value op2) {
152 assert((type.isIntOrIndexOrFloat() || fir::isa_complex(type)) &&
153 "only integer, float and complex types are currently supported");
154 if (type.isIntOrIndex())
155 return builder.create<IntegerOp>(loc, op1, op2);
156 if (fir::isa_real(type))
157 return builder.create<FloatOp>(loc, op1, op2);
158 return builder.create<ComplexOp>(loc, op1, op2);
159}
160
161} // namespace omp
162} // namespace lower
163} // namespace Fortran
164
165#endif // FORTRAN_LOWER_REDUCTIONPROCESSOR_H
166

source code of flang/lib/Lower/OpenMP/ReductionProcessor.h