1//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/Support/ReductionProcessor.h"
14
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/ConvertType.h"
17#include "flang/Lower/OpenMP/Clauses.h"
18#include "flang/Lower/Support/PrivateReductionUtils.h"
19#include "flang/Lower/SymbolMap.h"
20#include "flang/Optimizer/Builder/Complex.h"
21#include "flang/Optimizer/Builder/HLFIRTools.h"
22#include "flang/Optimizer/Builder/Todo.h"
23#include "flang/Optimizer/Dialect/FIRType.h"
24#include "flang/Optimizer/HLFIR/HLFIROps.h"
25#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26#include "llvm/Support/CommandLine.h"
27#include <type_traits>
28
29static llvm::cl::opt<bool> forceByrefReduction(
30 "force-byref-reduction",
31 llvm::cl::desc("Pass all reduction arguments by reference"),
32 llvm::cl::Hidden);
33
34using ReductionModifier =
35 Fortran::lower::omp::clause::Reduction::ReductionModifier;
36
37namespace Fortran {
38namespace lower {
39namespace omp {
40
// Explicit template instantiation definitions. The member templates below are
// defined later in this file; instantiating them here makes the two supported
// combinations available to other translation units:
//  - mlir::omp::DeclareReductionOp with an OpenMP ReductionOperatorList
//    (OpenMP REDUCTION clauses), and
//  - fir::DeclareReductionOp with a list of fir::ReduceOperationEnum
//    (`do concurrent` reductions).
template bool ReductionProcessor::processReductionArguments<
    mlir::omp::DeclareReductionOp, omp::clause::ReductionOperatorList>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::ReductionOperatorList &redOperatorList,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

template bool ReductionProcessor::processReductionArguments<
    fir::DeclareReductionOp, llvm::SmallVector<fir::ReduceOperationEnum>>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const llvm::SmallVector<fir::ReduceOperationEnum> &redOperatorList,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

template mlir::omp::DeclareReductionOp
ReductionProcessor::createDeclareReduction<mlir::omp::DeclareReductionOp>(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef);

template fir::DeclareReductionOp
ReductionProcessor::createDeclareReduction<fir::DeclareReductionOp>(
    AbstractConverter &converter, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef);
71
72ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
73 const omp::clause::ProcedureDesignator &pd) {
74 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
75 getRealName(pd.v.sym()).ToString())
76 .Case("max", ReductionIdentifier::MAX)
77 .Case("min", ReductionIdentifier::MIN)
78 .Case("iand", ReductionIdentifier::IAND)
79 .Case("ior", ReductionIdentifier::IOR)
80 .Case("ieor", ReductionIdentifier::IEOR)
81 .Default(std::nullopt);
82 assert(redType && "Invalid Reduction");
83 return *redType;
84}
85
86ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
87 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
88 switch (intrinsicOp) {
89 case omp::clause::DefinedOperator::IntrinsicOperator::Add:
90 return ReductionIdentifier::ADD;
91 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
92 return ReductionIdentifier::SUBTRACT;
93 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
94 return ReductionIdentifier::MULTIPLY;
95 case omp::clause::DefinedOperator::IntrinsicOperator::AND:
96 return ReductionIdentifier::AND;
97 case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
98 return ReductionIdentifier::EQV;
99 case omp::clause::DefinedOperator::IntrinsicOperator::OR:
100 return ReductionIdentifier::OR;
101 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
102 return ReductionIdentifier::NEQV;
103 default:
104 llvm_unreachable("unexpected intrinsic operator in reduction");
105 }
106}
107
108ReductionProcessor::ReductionIdentifier
109ReductionProcessor::getReductionType(const fir::ReduceOperationEnum &redOp) {
110 switch (redOp) {
111 case fir::ReduceOperationEnum::Add:
112 return ReductionIdentifier::ADD;
113 case fir::ReduceOperationEnum::Multiply:
114 return ReductionIdentifier::MULTIPLY;
115
116 case fir::ReduceOperationEnum::AND:
117 return ReductionIdentifier::AND;
118 case fir::ReduceOperationEnum::OR:
119 return ReductionIdentifier::OR;
120
121 case fir::ReduceOperationEnum::EQV:
122 return ReductionIdentifier::EQV;
123 case fir::ReduceOperationEnum::NEQV:
124 return ReductionIdentifier::NEQV;
125
126 case fir::ReduceOperationEnum::IAND:
127 return ReductionIdentifier::IAND;
128 case fir::ReduceOperationEnum::IEOR:
129 return ReductionIdentifier::IEOR;
130 case fir::ReduceOperationEnum::IOR:
131 return ReductionIdentifier::IOR;
132 case fir::ReduceOperationEnum::MAX:
133 return ReductionIdentifier::MAX;
134 case fir::ReduceOperationEnum::MIN:
135 return ReductionIdentifier::MIN;
136 }
137 llvm_unreachable("Unhandled ReductionIdentifier case");
138}
139
140bool ReductionProcessor::supportedIntrinsicProcReduction(
141 const omp::clause::ProcedureDesignator &pd) {
142 semantics::Symbol *sym = pd.v.sym();
143 if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
144 return false;
145 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
146 .Case("max", true)
147 .Case("min", true)
148 .Case("iand", true)
149 .Case("ior", true)
150 .Case("ieor", true)
151 .Default(false);
152 return redType;
153}
154
155std::string
156ReductionProcessor::getReductionName(llvm::StringRef name,
157 const fir::KindMapping &kindMap,
158 mlir::Type ty, bool isByRef) {
159 ty = fir::unwrapRefType(ty);
160
161 // extra string to distinguish reduction functions for variables passed by
162 // reference
163 llvm::StringRef byrefAddition{""};
164 if (isByRef)
165 byrefAddition = "_byref";
166
167 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
168}
169
170std::string
171ReductionProcessor::getReductionName(ReductionIdentifier redId,
172 const fir::KindMapping &kindMap,
173 mlir::Type ty, bool isByRef) {
174 std::string reductionName;
175
176 switch (redId) {
177 case ReductionIdentifier::ADD:
178 reductionName = "add_reduction";
179 break;
180 case ReductionIdentifier::MULTIPLY:
181 reductionName = "multiply_reduction";
182 break;
183 case ReductionIdentifier::AND:
184 reductionName = "and_reduction";
185 break;
186 case ReductionIdentifier::EQV:
187 reductionName = "eqv_reduction";
188 break;
189 case ReductionIdentifier::OR:
190 reductionName = "or_reduction";
191 break;
192 case ReductionIdentifier::NEQV:
193 reductionName = "neqv_reduction";
194 break;
195 default:
196 reductionName = "other_reduction";
197 break;
198 }
199
200 return getReductionName(reductionName, kindMap, ty, isByRef);
201}
202
203mlir::Value
204ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
205 ReductionIdentifier redId,
206 fir::FirOpBuilder &builder) {
207 type = fir::unwrapRefType(type);
208 if (!fir::isa_integer(type) && !fir::isa_real(type) &&
209 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
210 TODO(loc, "Reduction of some types is not supported");
211 switch (redId) {
212 case ReductionIdentifier::MAX: {
213 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
214 const llvm::fltSemantics &sem = ty.getFloatSemantics();
215 return builder.createRealConstant(
216 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
217 }
218 unsigned bits = type.getIntOrFloatBitWidth();
219 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
220 return builder.createIntegerConstant(loc, type, minInt);
221 }
222 case ReductionIdentifier::MIN: {
223 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
224 const llvm::fltSemantics &sem = ty.getFloatSemantics();
225 return builder.createRealConstant(
226 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
227 }
228 unsigned bits = type.getIntOrFloatBitWidth();
229 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
230 return builder.createIntegerConstant(loc, type, maxInt);
231 }
232 case ReductionIdentifier::IOR: {
233 unsigned bits = type.getIntOrFloatBitWidth();
234 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
235 return builder.createIntegerConstant(loc, type, zeroInt);
236 }
237 case ReductionIdentifier::IEOR: {
238 unsigned bits = type.getIntOrFloatBitWidth();
239 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
240 return builder.createIntegerConstant(loc, type, zeroInt);
241 }
242 case ReductionIdentifier::IAND: {
243 unsigned bits = type.getIntOrFloatBitWidth();
244 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
245 return builder.createIntegerConstant(loc, type, allOnInt);
246 }
247 case ReductionIdentifier::ADD:
248 case ReductionIdentifier::MULTIPLY:
249 case ReductionIdentifier::AND:
250 case ReductionIdentifier::OR:
251 case ReductionIdentifier::EQV:
252 case ReductionIdentifier::NEQV:
253 if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
254 mlir::Type realTy = cplxTy.getElementType();
255 mlir::Value initRe = builder.createRealConstant(
256 loc, realTy, getOperationIdentity(redId, loc));
257 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
258
259 return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
260 initIm);
261 }
262 if (mlir::isa<mlir::FloatType>(type))
263 return builder.create<mlir::arith::ConstantOp>(
264 loc, type,
265 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
266
267 if (mlir::isa<fir::LogicalType>(type)) {
268 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
269 loc, builder.getI1Type(),
270 builder.getIntegerAttr(builder.getI1Type(),
271 getOperationIdentity(redId, loc)));
272 return builder.createConvert(loc, type, intConst);
273 }
274
275 return builder.create<mlir::arith::ConstantOp>(
276 loc, type,
277 builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
278 case ReductionIdentifier::ID:
279 case ReductionIdentifier::USER_DEF_OP:
280 case ReductionIdentifier::SUBTRACT:
281 TODO(loc, "Reduction of some identifier types is not supported");
282 }
283 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
284}
285
286mlir::Value ReductionProcessor::createScalarCombiner(
287 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
288 mlir::Type type, mlir::Value op1, mlir::Value op2) {
289 mlir::Value reductionOp;
290 type = fir::unwrapRefType(type);
291 switch (redId) {
292 case ReductionIdentifier::MAX:
293 reductionOp =
294 getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
295 builder, type, loc, op1, op2);
296 break;
297 case ReductionIdentifier::MIN:
298 reductionOp =
299 getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
300 builder, type, loc, op1, op2);
301 break;
302 case ReductionIdentifier::IOR:
303 assert((type.isIntOrIndex()) && "only integer is expected");
304 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
305 break;
306 case ReductionIdentifier::IEOR:
307 assert((type.isIntOrIndex()) && "only integer is expected");
308 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
309 break;
310 case ReductionIdentifier::IAND:
311 assert((type.isIntOrIndex()) && "only integer is expected");
312 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
313 break;
314 case ReductionIdentifier::ADD:
315 reductionOp =
316 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
317 fir::AddcOp>(builder, type, loc, op1, op2);
318 break;
319 case ReductionIdentifier::MULTIPLY:
320 reductionOp =
321 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
322 fir::MulcOp>(builder, type, loc, op1, op2);
323 break;
324 case ReductionIdentifier::AND: {
325 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
326 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
327
328 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
329
330 reductionOp = builder.createConvert(loc, type, andiOp);
331 break;
332 }
333 case ReductionIdentifier::OR: {
334 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
335 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
336
337 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
338
339 reductionOp = builder.createConvert(loc, type, oriOp);
340 break;
341 }
342 case ReductionIdentifier::EQV: {
343 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
344 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
345
346 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
347 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
348
349 reductionOp = builder.createConvert(loc, type, cmpiOp);
350 break;
351 }
352 case ReductionIdentifier::NEQV: {
353 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
354 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
355
356 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
357 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
358
359 reductionOp = builder.createConvert(loc, type, cmpiOp);
360 break;
361 }
362 default:
363 TODO(loc, "Reduction of some intrinsic operators is not supported");
364 }
365
366 return reductionOp;
367}
368
369template <typename ParentDeclOpType>
370static void genYield(fir::FirOpBuilder &builder, mlir::Location loc,
371 mlir::Value yieldedValue) {
372 if constexpr (std::is_same_v<ParentDeclOpType, mlir::omp::DeclareReductionOp>)
373 builder.create<mlir::omp::YieldOp>(loc, yieldedValue);
374 else
375 builder.create<fir::YieldOp>(loc, yieldedValue);
376}
377
/// Create the reduction combiner region body for reduction variables that are
/// boxed: arrays, or scalar heap/pointer contents held in a box.
///
/// `lhs` and `rhs` are references to the boxes (both are loaded below). The
/// combined result is stored back through the lhs data, and the region yields
/// the original `lhs` box reference.
template <typename DeclRedOpType>
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  fir::PointerType ptrTy =
      mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
  // Reject boxes that hold neither a sequence with a known shape nor
  // heap/pointer contents.
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  // Scalar heap/pointer contents: combine the pointed-to values directly.
  if ((heapTy || ptrTy) && !seqTy) {
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    mlir::Value lhsValAddr = lhs;

    // load the scalar values the heap pointers point to
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, eleTy, lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    genYield<DeclRedOpType>(builder, loc, lhsAddr);
    return;
  }

  // Get ShapeShift with default lower bounds. This makes it possible to use
  // unmodified LoopNest's indices with ArrayCoorOp.
  fir::ShapeShiftOp shapeShift =
      getShapeShift(builder, loc, lhs,
                    /*cannotHaveNonDefaultLowerBounds=*/false,
                    /*useDefaultLowerBounds=*/true);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.body);
  // NOTE(review): volatility is queried on the array *element* type here, not
  // on the box/reference type; confirm this is the intended source of the
  // volatile flag for the element references.
  const bool seqIsVolatile = fir::isa_volatile_type(seqTy.getEleTy());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy(), seqIsVolatile);
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  // Accumulate element-wise into the lhs array.
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  builder.setInsertionPointAfter(nest.outerOp);
  genYield<DeclRedOpType>(builder, loc, lhsAddr);
}
455
456// generate combiner region for reduction operations
457template <typename DeclRedOpType>
458static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
459 ReductionProcessor::ReductionIdentifier redId,
460 mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
461 bool isByRef) {
462 ty = fir::unwrapRefType(ty);
463
464 if (fir::isa_trivial(ty)) {
465 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
466 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
467
468 mlir::Value result = ReductionProcessor::createScalarCombiner(
469 builder, loc, redId, ty, lhsLoaded, rhsLoaded);
470 if (isByRef) {
471 builder.create<fir::StoreOp>(loc, result, lhs);
472 genYield<DeclRedOpType>(builder, loc, lhs);
473 } else {
474 genYield<DeclRedOpType>(builder, loc, result);
475 }
476 return;
477 }
478 // all arrays should have been boxed
479 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
480 genBoxCombiner<DeclRedOpType>(builder, loc, redId, boxTy, lhs, rhs);
481 return;
482 }
483
484 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
485}
486
487// like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
488static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
489 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
490 return seqTy.getEleTy();
491 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
492 auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
493 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
494 return seqTy.getEleTy();
495 return eleTy;
496 }
497 return ty;
498}
499
/// Populate the alloc and initializer regions (and, by reference, the cleanup
/// region) of `reductionDecl`.
///
/// By value, only an initializer region is created. It has one block argument
/// of `type` and yields the reduction's initial value.
///
/// By reference:
///  - An argument-less alloc region is created. It yields a fresh
///    fir::AllocaOp of the reference-stripped type.
///  - The initializer region gets two block arguments of `type`.
///  - Filling the initializer body and the cleanup region is delegated to
///    populateByRefInitAndCleanupRegions.
///
/// Non-trivial types are only supported by reference (asserted).
template <typename OpType>
static void createReductionAllocAndInitRegions(
    AbstractConverter &converter, mlir::Location loc, OpType &reductionDecl,
    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
    bool isByRef) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  auto yield = [&](mlir::Value ret) { genYield<OpType>(builder, loc, ret); };

  // allocBlock stays null when passing by value; it is only used on by-ref
  // paths below.
  mlir::Block *allocBlock = nullptr;
  mlir::Block *initBlock = nullptr;
  if (isByRef) {
    allocBlock =
        builder.createBlock(&reductionDecl.getAllocRegion(),
                            reductionDecl.getAllocRegion().end(), {}, {});
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type, type}, {loc, loc});
  } else {
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type}, {loc});
  }

  mlir::Type ty = fir::unwrapRefType(type);
  builder.setInsertionPointToEnd(initBlock);
  // The initial value is computed for the scalar element type: sequences and
  // boxes are unwrapped first.
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (isByRef) {
    populateByRefInitAndCleanupRegions(
        converter, loc, type, initValue, initBlock,
        reductionDecl.getInitializerAllocArg(),
        reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
        DeclOperationKind::Reduction, /*sym=*/nullptr,
        /*cannotHaveLowerBounds=*/false,
        /*isDoConcurrent*/ std::is_same_v<OpType, fir::DeclareReductionOp>);
  }

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      // alloc region
      builder.setInsertionPointToEnd(allocBlock);
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      yield(alloca);
      return;
    }
    // by val
    yield(initValue);
    return;
  }
  assert(isByRef && "passing non-trivial types by val is unsupported");

  // alloc region
  builder.setInsertionPointToEnd(allocBlock);
  mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
  yield(boxAlloca);
}
557
558template <typename OpType>
559OpType ReductionProcessor::createDeclareReduction(
560 AbstractConverter &converter, llvm::StringRef reductionOpName,
561 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
562 bool isByRef) {
563 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
564 mlir::OpBuilder::InsertionGuard guard(builder);
565 mlir::ModuleOp module = builder.getModule();
566
567 assert(!reductionOpName.empty());
568
569 auto decl = module.lookupSymbol<OpType>(reductionOpName);
570 if (decl)
571 return decl;
572
573 mlir::OpBuilder modBuilder(module.getBodyRegion());
574 mlir::Type valTy = fir::unwrapRefType(type);
575 if (!isByRef)
576 type = valTy;
577
578 decl = modBuilder.create<OpType>(loc, reductionOpName, type);
579 createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
580 isByRef);
581
582 builder.createBlock(&decl.getReductionRegion(),
583 decl.getReductionRegion().end(), {type, type},
584 {loc, loc});
585
586 builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
587 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
588 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
589 genCombiner<OpType>(builder, loc, redId, type, op1, op2, isByRef);
590
591 return decl;
592}
593
594static bool doReductionByRef(mlir::Value reductionVar) {
595 if (forceByrefReduction)
596 return true;
597
598 if (auto declare =
599 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
600 reductionVar = declare.getMemref();
601
602 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
603 return true;
604
605 return false;
606}
607
/// Lower the arguments of a reduction clause and create/reuse the matching
/// reduction declarations.
///
/// For every symbol in `reductionSymbols`, in order, this appends:
///  - to `reductionVars`: a reference to the reduction variable. Arrays are
///    boxed, and boxes are spilled to a stack slot, so that a reference is
///    always passed.
///  - to `reduceVarByRef`: whether the variable is reduced by reference
///    (see doReductionByRef).
///  - to `reductionDeclSymbols`: a symbol reference to the reduction
///    declaration obtained from createDeclareReduction.
///
/// `OpType` selects the declaration op:
///  - mlir::omp::DeclareReductionOp for OpenMP reduction clauses. Here
///    `redOperatorList` must hold exactly one operator, which applies to all
///    symbols.
///  - fir::DeclareReductionOp for `do concurrent`. Here `redOperatorList`
///    holds one operator per symbol, and the setup code is emitted in front of
///    the enclosing fir::DoConcurrentOp.
///
/// \returns false, before emitting anything, if the OpenMP reduction operator
/// is neither a defined operator nor a supported intrinsic procedure;
/// true otherwise.
template <typename OpType, typename RedOperatorListTy>
bool ReductionProcessor::processReductionArguments(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const RedOperatorListTy &redOperatorList,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    const llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
  if constexpr (std::is_same_v<RedOperatorListTy,
                               omp::clause::ReductionOperatorList>) {
    // For OpenMP reduction clauses, check if the reduction operator is
    // supported.
    assert(redOperatorList.size() == 1 && "Expecting single operator");
    const Fortran::lower::omp::clause::ReductionOperator &redOperator =
        redOperatorList.front();

    if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
      if (const auto *reductionIntrinsic =
              std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
        if (!ReductionProcessor::supportedIntrinsicProcReduction(
                *reductionIntrinsic)) {
          return false;
        }
      } else {
        return false;
      }
    }
  }

  // Reduction variable processing common to both intrinsic operators and
  // procedure designators
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  mlir::OpBuilder::InsertPoint dcIP;
  constexpr bool isDoConcurrent =
      std::is_same_v<OpType, fir::DeclareReductionOp>;

  // For `do concurrent`, emit the setup code in front of the enclosing
  // fir::DoConcurrentOp; the saved insertion point is restored at the end.
  if (isDoConcurrent) {
    dcIP = builder.saveInsertionPoint();
    builder.setInsertionPoint(
        builder.getRegion().getParentOfType<fir::DoConcurrentOp>());
  }

  // First pass: materialize a reference for each reduction variable and
  // decide how it is passed.
  for (const semantics::Symbol *symbol : reductionSymbols) {
    mlir::Value symVal = converter.getSymbolAddress(*symbol);

    if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
      symVal = declOp.getBase();

    mlir::Type eleType;
    auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
    if (refType)
      eleType = refType.getEleTy();
    else
      eleType = symVal.getType();

    // all arrays must be boxed so that we have convenient access to all the
    // information needed to iterate over the array
    if (mlir::isa<fir::SequenceType>(eleType)) {
      // For Host associated symbols, use `SymbolBox` instead
      lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
      hlfir::Entity entity{symBox.getAddr()};
      entity = genVariableBox(currentLocation, builder, entity);
      mlir::Value box = entity.getBase();

      // Always pass the box by reference so that the OpenMP dialect
      // verifiers don't need to know anything about fir.box
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, box.getType());
      builder.create<fir::StoreOp>(currentLocation, box, alloca);

      symVal = alloca;
    } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
      // boxed arrays are passed as values not by reference. Unfortunately,
      // we can't pass a box by value to omp.declare_reduction, so turn it
      // into a reference
      auto oldIP = builder.saveInsertionPoint();
      builder.setInsertionPointToStart(builder.getAllocaBlock());
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
      builder.restoreInsertionPoint(oldIP);
      builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
      symVal = alloca;
    }

    // this isn't the same as the by-val and by-ref passing later in the
    // pipeline. Both styles assume that the variable is a reference at
    // this point
    assert(fir::isa_ref_type(symVal.getType()) &&
           "reduction input var is passed by reference");
    mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
    const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
    mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);

    reductionVars.push_back(
        builder.createConvert(currentLocation, refTy, symVal));
    reduceVarByRef.push_back(doReductionByRef(symVal));
  }

  // Second pass: pick the reduction identifier and declaration name for each
  // variable, and create (or reuse) the declaration.
  unsigned idx = 0;
  for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
    auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
    const auto &kindMap = builder.getKindMap();
    std::string reductionName;
    ReductionIdentifier redId;

    if constexpr (std::is_same_v<RedOperatorListTy,
                                 omp::clause::ReductionOperatorList>) {
      const Fortran::lower::omp::clause::ReductionOperator &redOperator =
          redOperatorList.front();
      // NOTE(review): this binds a reference to the pointer returned by
      // std::get_if; `const auto *` would state the intent more clearly.
      if (const auto &redDefinedOp =
              std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
        const auto &intrinsicOp{
            std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
                redDefinedOp->u)};
        redId = getReductionType(intrinsicOp);
        switch (redId) {
        case ReductionIdentifier::ADD:
        case ReductionIdentifier::MULTIPLY:
        case ReductionIdentifier::AND:
        case ReductionIdentifier::EQV:
        case ReductionIdentifier::OR:
        case ReductionIdentifier::NEQV:
          break;
        default:
          TODO(currentLocation,
               "Reduction of some intrinsic operators is not supported");
          break;
        }

        reductionName = getReductionName(redId, kindMap, redType, isByRef);
      } else if (const auto *reductionIntrinsic =
                     std::get_if<omp::clause::ProcedureDesignator>(
                         &redOperator.u)) {
        if (!ReductionProcessor::supportedIntrinsicProcReduction(
                *reductionIntrinsic)) {
          TODO(currentLocation, "Unsupported intrinsic proc reduction");
        }
        redId = getReductionType(*reductionIntrinsic);
        reductionName =
            getReductionName(getRealName(*reductionIntrinsic).ToString(),
                             kindMap, redType, isByRef);
      } else {
        TODO(currentLocation, "Unexpected reduction type");
      }
    } else {
      // `do concurrent` reductions
      redId = getReductionType(redOperatorList[idx]);
      reductionName = getReductionName(redId, kindMap, redType, isByRef);
    }

    OpType decl = createDeclareReduction<OpType>(
        converter, reductionName, redId, redType, currentLocation, isByRef);
    reductionDeclSymbols.push_back(
        mlir::SymbolRefAttr::get(builder.getContext(), decl.getSymName()));
    ++idx;
  }

  if (isDoConcurrent)
    builder.restoreInsertionPoint(dcIP);

  return true;
}
770
771const semantics::SourceName
772ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
773 return symbol->GetUltimate().name();
774}
775
776const semantics::SourceName
777ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
778 return getRealName(pd.v.sym());
779}
780
781int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
782 mlir::Location loc) {
783 switch (redId) {
784 case ReductionIdentifier::ADD:
785 case ReductionIdentifier::OR:
786 case ReductionIdentifier::NEQV:
787 return 0;
788 case ReductionIdentifier::MULTIPLY:
789 case ReductionIdentifier::AND:
790 case ReductionIdentifier::EQV:
791 return 1;
792 default:
793 TODO(loc, "Reduction of some intrinsic operators is not supported");
794 }
795}
796
797} // namespace omp
798} // namespace lower
799} // namespace Fortran
800

source code of flang/lib/Lower/Support/ReductionProcessor.cpp