1//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "ReductionProcessor.h"
14
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/ConvertType.h"
17#include "flang/Lower/Support/PrivateReductionUtils.h"
18#include "flang/Lower/SymbolMap.h"
19#include "flang/Optimizer/Builder/Complex.h"
20#include "flang/Optimizer/Builder/HLFIRTools.h"
21#include "flang/Optimizer/Builder/Todo.h"
22#include "flang/Optimizer/Dialect/FIRType.h"
23#include "flang/Optimizer/HLFIR/HLFIROps.h"
24#include "flang/Optimizer/Support/FatalError.h"
25#include "flang/Parser/tools.h"
26#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
27#include "llvm/Support/CommandLine.h"
28#include <type_traits>
29
// Hidden command-line flag. When it is set, doReductionByRef() in this file
// answers "by reference" for every reduction variable, regardless of its type.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);

// Short alias for the modifier enumeration carried by the `reduction` clause.
using ReductionModifier =
    Fortran::lower::omp::clause::Reduction::ReductionModifier;
37
38namespace Fortran {
39namespace lower {
40namespace omp {
41
// Explicit instantiations of ReductionProcessor::processReductionArguments for
// the three clause types handled here: Reduction, TaskReduction and
// InReduction. The template itself is defined further down in this file.
template void
ReductionProcessor::processReductionArguments<omp::clause::Reduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::Reduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::TaskReduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::TaskReduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::InReduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::InReduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);
72
73ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
74 const omp::clause::ProcedureDesignator &pd) {
75 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
76 getRealName(pd.v.sym()).ToString())
77 .Case("max", ReductionIdentifier::MAX)
78 .Case("min", ReductionIdentifier::MIN)
79 .Case("iand", ReductionIdentifier::IAND)
80 .Case("ior", ReductionIdentifier::IOR)
81 .Case("ieor", ReductionIdentifier::IEOR)
82 .Default(std::nullopt);
83 assert(redType && "Invalid Reduction");
84 return *redType;
85}
86
87ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
88 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
89 switch (intrinsicOp) {
90 case omp::clause::DefinedOperator::IntrinsicOperator::Add:
91 return ReductionIdentifier::ADD;
92 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
93 return ReductionIdentifier::SUBTRACT;
94 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
95 return ReductionIdentifier::MULTIPLY;
96 case omp::clause::DefinedOperator::IntrinsicOperator::AND:
97 return ReductionIdentifier::AND;
98 case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
99 return ReductionIdentifier::EQV;
100 case omp::clause::DefinedOperator::IntrinsicOperator::OR:
101 return ReductionIdentifier::OR;
102 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
103 return ReductionIdentifier::NEQV;
104 default:
105 llvm_unreachable("unexpected intrinsic operator in reduction");
106 }
107}
108
109bool ReductionProcessor::supportedIntrinsicProcReduction(
110 const omp::clause::ProcedureDesignator &pd) {
111 semantics::Symbol *sym = pd.v.sym();
112 if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC))
113 return false;
114 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
115 .Case("max", true)
116 .Case("min", true)
117 .Case("iand", true)
118 .Case("ior", true)
119 .Case("ieor", true)
120 .Default(false);
121 return redType;
122}
123
124std::string
125ReductionProcessor::getReductionName(llvm::StringRef name,
126 const fir::KindMapping &kindMap,
127 mlir::Type ty, bool isByRef) {
128 ty = fir::unwrapRefType(ty);
129
130 // extra string to distinguish reduction functions for variables passed by
131 // reference
132 llvm::StringRef byrefAddition{""};
133 if (isByRef)
134 byrefAddition = "_byref";
135
136 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
137}
138
139std::string ReductionProcessor::getReductionName(
140 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
141 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
142 std::string reductionName;
143
144 switch (intrinsicOp) {
145 case omp::clause::DefinedOperator::IntrinsicOperator::Add:
146 reductionName = "add_reduction";
147 break;
148 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
149 reductionName = "multiply_reduction";
150 break;
151 case omp::clause::DefinedOperator::IntrinsicOperator::AND:
152 return "and_reduction";
153 case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
154 return "eqv_reduction";
155 case omp::clause::DefinedOperator::IntrinsicOperator::OR:
156 return "or_reduction";
157 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
158 return "neqv_reduction";
159 default:
160 reductionName = "other_reduction";
161 break;
162 }
163
164 return getReductionName(reductionName, kindMap, ty, isByRef);
165}
166
167mlir::Value
168ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
169 ReductionIdentifier redId,
170 fir::FirOpBuilder &builder) {
171 type = fir::unwrapRefType(type);
172 if (!fir::isa_integer(type) && !fir::isa_real(type) &&
173 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
174 TODO(loc, "Reduction of some types is not supported");
175 switch (redId) {
176 case ReductionIdentifier::MAX: {
177 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
178 const llvm::fltSemantics &sem = ty.getFloatSemantics();
179 return builder.createRealConstant(
180 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
181 }
182 unsigned bits = type.getIntOrFloatBitWidth();
183 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
184 return builder.createIntegerConstant(loc, type, minInt);
185 }
186 case ReductionIdentifier::MIN: {
187 if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) {
188 const llvm::fltSemantics &sem = ty.getFloatSemantics();
189 return builder.createRealConstant(
190 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
191 }
192 unsigned bits = type.getIntOrFloatBitWidth();
193 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
194 return builder.createIntegerConstant(loc, type, maxInt);
195 }
196 case ReductionIdentifier::IOR: {
197 unsigned bits = type.getIntOrFloatBitWidth();
198 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
199 return builder.createIntegerConstant(loc, type, zeroInt);
200 }
201 case ReductionIdentifier::IEOR: {
202 unsigned bits = type.getIntOrFloatBitWidth();
203 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
204 return builder.createIntegerConstant(loc, type, zeroInt);
205 }
206 case ReductionIdentifier::IAND: {
207 unsigned bits = type.getIntOrFloatBitWidth();
208 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
209 return builder.createIntegerConstant(loc, type, allOnInt);
210 }
211 case ReductionIdentifier::ADD:
212 case ReductionIdentifier::MULTIPLY:
213 case ReductionIdentifier::AND:
214 case ReductionIdentifier::OR:
215 case ReductionIdentifier::EQV:
216 case ReductionIdentifier::NEQV:
217 if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) {
218 mlir::Type realTy = cplxTy.getElementType();
219 mlir::Value initRe = builder.createRealConstant(
220 loc, realTy, getOperationIdentity(redId, loc));
221 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
222
223 return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
224 initIm);
225 }
226 if (mlir::isa<mlir::FloatType>(type))
227 return builder.create<mlir::arith::ConstantOp>(
228 loc, type,
229 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
230
231 if (mlir::isa<fir::LogicalType>(type)) {
232 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
233 loc, builder.getI1Type(),
234 builder.getIntegerAttr(builder.getI1Type(),
235 getOperationIdentity(redId, loc)));
236 return builder.createConvert(loc, type, intConst);
237 }
238
239 return builder.create<mlir::arith::ConstantOp>(
240 loc, type,
241 builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
242 case ReductionIdentifier::ID:
243 case ReductionIdentifier::USER_DEF_OP:
244 case ReductionIdentifier::SUBTRACT:
245 TODO(loc, "Reduction of some identifier types is not supported");
246 }
247 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
248}
249
250mlir::Value ReductionProcessor::createScalarCombiner(
251 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
252 mlir::Type type, mlir::Value op1, mlir::Value op2) {
253 mlir::Value reductionOp;
254 type = fir::unwrapRefType(type);
255 switch (redId) {
256 case ReductionIdentifier::MAX:
257 reductionOp =
258 getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
259 builder, type, loc, op1, op2);
260 break;
261 case ReductionIdentifier::MIN:
262 reductionOp =
263 getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
264 builder, type, loc, op1, op2);
265 break;
266 case ReductionIdentifier::IOR:
267 assert((type.isIntOrIndex()) && "only integer is expected");
268 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
269 break;
270 case ReductionIdentifier::IEOR:
271 assert((type.isIntOrIndex()) && "only integer is expected");
272 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
273 break;
274 case ReductionIdentifier::IAND:
275 assert((type.isIntOrIndex()) && "only integer is expected");
276 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
277 break;
278 case ReductionIdentifier::ADD:
279 reductionOp =
280 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
281 fir::AddcOp>(builder, type, loc, op1, op2);
282 break;
283 case ReductionIdentifier::MULTIPLY:
284 reductionOp =
285 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
286 fir::MulcOp>(builder, type, loc, op1, op2);
287 break;
288 case ReductionIdentifier::AND: {
289 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
290 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
291
292 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
293
294 reductionOp = builder.createConvert(loc, type, andiOp);
295 break;
296 }
297 case ReductionIdentifier::OR: {
298 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
299 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
300
301 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
302
303 reductionOp = builder.createConvert(loc, type, oriOp);
304 break;
305 }
306 case ReductionIdentifier::EQV: {
307 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
308 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
309
310 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
311 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
312
313 reductionOp = builder.createConvert(loc, type, cmpiOp);
314 break;
315 }
316 case ReductionIdentifier::NEQV: {
317 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
318 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
319
320 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
321 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
322
323 reductionOp = builder.createConvert(loc, type, cmpiOp);
324 break;
325 }
326 default:
327 TODO(loc, "Reduction of some intrinsic operators is not supported");
328 }
329
330 return reductionOp;
331}
332
/// Create reduction combiner region for reduction variables which are boxed
/// arrays
///
/// \param builder builder positioned inside the combiner region.
/// \param loc     location attached to the generated operations.
/// \param redId   reduction identifier forwarded to createScalarCombiner.
/// \param boxTy   box type of the reduction variable.
/// \param lhs     reference to the box that accumulates the result.
/// \param rhs     reference to the box holding the value to combine in.
///
/// Two shapes are handled: a box around a heap/pointer scalar (combined with
/// a single scalar operation) and a box around an array (combined element by
/// element in a loop nest). In both cases the result is stored through the
/// left-hand side and the original \p lhs reference is yielded.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  // seqTy is looked up after fir::unwrapRefType has been applied to the box
  // element type; heapTy/ptrTy are looked up on the element type directly.
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  fir::PointerType ptrTy =
      mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
  // Only known-shape arrays and heap/pointer boxes are supported.
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  // Keep the incoming reference: it is what the region yields at the end.
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  // Boxed scalar (heap or pointer element that is not a sequence).
  if ((heapTy || ptrTy) && !seqTy) {
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    // Address the combined result is stored back to.
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, eleTy, lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  // Get ShapeShift with default lower bounds. This makes it possible to use
  // unmodified LoopNest's indices with ArrayCoorOp.
  fir::ShapeShiftOp shapeShift =
      getShapeShift(builder, loc, lhs,
                    /*cannotHaveNonDefaultLowerBounds=*/false,
                    /*useDefaultLowerBounds=*/true);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.body);
  // Element references keep the volatility of the array element type.
  const bool seqIsVolatile = fir::isa_volatile_type(seqTy.getEleTy());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy(), seqIsVolatile);
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  // createScalarCombiner strips the reference wrapper from refTy itself.
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  // Leave the loop nest before emitting the region terminator.
  builder.setInsertionPointAfter(nest.outerOp);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
409
410// generate combiner region for reduction operations
411static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
412 ReductionProcessor::ReductionIdentifier redId,
413 mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
414 bool isByRef) {
415 ty = fir::unwrapRefType(ty);
416
417 if (fir::isa_trivial(ty)) {
418 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
419 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
420
421 mlir::Value result = ReductionProcessor::createScalarCombiner(
422 builder, loc, redId, ty, lhsLoaded, rhsLoaded);
423 if (isByRef) {
424 builder.create<fir::StoreOp>(loc, result, lhs);
425 builder.create<mlir::omp::YieldOp>(loc, lhs);
426 } else {
427 builder.create<mlir::omp::YieldOp>(loc, result);
428 }
429 return;
430 }
431 // all arrays should have been boxed
432 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
433 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
434 return;
435 }
436
437 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
438}
439
440// like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
441static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
442 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty))
443 return seqTy.getEleTy();
444 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
445 auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
446 if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy))
447 return seqTy.getEleTy();
448 return eleTy;
449 }
450 return ty;
451}
452
453static void createReductionAllocAndInitRegions(
454 AbstractConverter &converter, mlir::Location loc,
455 mlir::omp::DeclareReductionOp &reductionDecl,
456 const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
457 bool isByRef) {
458 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
459 auto yield = [&](mlir::Value ret) {
460 builder.create<mlir::omp::YieldOp>(loc, ret);
461 };
462
463 mlir::Block *allocBlock = nullptr;
464 mlir::Block *initBlock = nullptr;
465 if (isByRef) {
466 allocBlock =
467 builder.createBlock(&reductionDecl.getAllocRegion(),
468 reductionDecl.getAllocRegion().end(), {}, {});
469 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
470 reductionDecl.getInitializerRegion().end(),
471 {type, type}, {loc, loc});
472 } else {
473 initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
474 reductionDecl.getInitializerRegion().end(),
475 {type}, {loc});
476 }
477
478 mlir::Type ty = fir::unwrapRefType(type);
479 builder.setInsertionPointToEnd(initBlock);
480 mlir::Value initValue = ReductionProcessor::getReductionInitValue(
481 loc, unwrapSeqOrBoxedType(ty), redId, builder);
482
483 if (isByRef) {
484 populateByRefInitAndCleanupRegions(
485 converter, loc, type, initValue, initBlock,
486 reductionDecl.getInitializerAllocArg(),
487 reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
488 DeclOperationKind::Reduction);
489 }
490
491 if (fir::isa_trivial(ty)) {
492 if (isByRef) {
493 // alloc region
494 builder.setInsertionPointToEnd(allocBlock);
495 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
496 yield(alloca);
497 return;
498 }
499 // by val
500 yield(initValue);
501 return;
502 }
503 assert(isByRef && "passing non-trivial types by val is unsupported");
504
505 // alloc region
506 builder.setInsertionPointToEnd(allocBlock);
507 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
508 yield(boxAlloca);
509}
510
511mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
512 AbstractConverter &converter, llvm::StringRef reductionOpName,
513 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
514 bool isByRef) {
515 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
516 mlir::OpBuilder::InsertionGuard guard(builder);
517 mlir::ModuleOp module = builder.getModule();
518
519 assert(!reductionOpName.empty());
520
521 auto decl =
522 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
523 if (decl)
524 return decl;
525
526 mlir::OpBuilder modBuilder(module.getBodyRegion());
527 mlir::Type valTy = fir::unwrapRefType(type);
528 if (!isByRef)
529 type = valTy;
530
531 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
532 type);
533 createReductionAllocAndInitRegions(converter, loc, decl, redId, type,
534 isByRef);
535
536 builder.createBlock(&decl.getReductionRegion(),
537 decl.getReductionRegion().end(), {type, type},
538 {loc, loc});
539
540 builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
541 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
542 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
543 genCombiner(builder, loc, redId, type, op1, op2, isByRef);
544
545 return decl;
546}
547
548static bool doReductionByRef(mlir::Value reductionVar) {
549 if (forceByrefReduction)
550 return true;
551
552 if (auto declare =
553 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
554 reductionVar = declare.getMemref();
555
556 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
557 return true;
558
559 return false;
560}
561
562mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) {
563 switch (mod) {
564 case ReductionModifier::Default:
565 return mlir::omp::ReductionModifier::defaultmod;
566 case ReductionModifier::Inscan:
567 return mlir::omp::ReductionModifier::inscan;
568 case ReductionModifier::Task:
569 return mlir::omp::ReductionModifier::task;
570 }
571 return mlir::omp::ReductionModifier::defaultmod;
572}
573
/// Lower the arguments of a reduction-like clause (the clause type \p T).
///
/// \param currentLocation       location attached to generated operations.
/// \param converter             lowering converter (builder, symbol lookup).
/// \param reduction             the clause being lowered.
/// \param[out] reductionVars        receives one reference value per object.
/// \param[out] reduceVarByRef       receives, per object, whether it is
///                                  reduced by reference.
/// \param[out] reductionDeclSymbols receives a symbol reference to the
///                                  reduction declaration of each variable.
/// \param[out] reductionSymbols     receives the semantic symbol per object.
/// \param[out] reductionMod         written only for omp::clause::Reduction
///                                  clauses that carry a modifier.
///
/// The function returns without touching the output vectors when the
/// reduction operator is neither a defined operator nor a supported intrinsic
/// procedure designator.
template <class T>
void ReductionProcessor::processReductionArguments(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Only the `reduction` clause is inspected for a modifier.
  if constexpr (std::is_same_v<T, omp::clause::Reduction>) {
    auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
    if (mod.has_value()) {
      if (mod.value() == ReductionModifier::Task)
        TODO(currentLocation, "Reduction modifier `task` is not supported");
      else
        // NOTE(review): reductionMod is dereferenced without a null check;
        // confirm every caller lowering a `reduction` clause passes a valid
        // pointer.
        *reductionMod = mlir::omp::ReductionModifierAttr::get(
            firOpBuilder.getContext(), translateReductionModifier(mod.value()));
    }
  }

  mlir::omp::DeclareReductionOp decl;
  const auto &redOperatorList{
      std::get<typename T::ReductionIdentifiers>(reduction.t)};
  assert(redOperatorList.size() == 1 && "Expecting single operator");
  const auto &redOperator = redOperatorList.front();
  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

  // Bail out early (leaving all outputs untouched) unless the operator is a
  // defined operator or a supported intrinsic procedure.
  if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
    if (const auto *reductionIntrinsic =
            std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        return;
      }
    } else {
      return;
    }
  }

  // Reduction variable processing common to both intrinsic operators and
  // procedure designators
  // `builder` refers to the same FirOpBuilder as `firOpBuilder` above.
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  for (const Object &object : objectList) {
    const semantics::Symbol *symbol = object.sym();
    reductionSymbols.push_back(symbol);
    mlir::Value symVal = converter.getSymbolAddress(*symbol);
    // eleType is the pointee type for references, else the value's own type.
    mlir::Type eleType;
    auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
    if (refType)
      eleType = refType.getEleTy();
    else
      eleType = symVal.getType();

    // all arrays must be boxed so that we have convenient access to all the
    // information needed to iterate over the array
    if (mlir::isa<fir::SequenceType>(eleType)) {
      // For Host associated symbols, use `SymbolBox` instead
      lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
      hlfir::Entity entity{symBox.getAddr()};
      entity = genVariableBox(currentLocation, builder, entity);
      mlir::Value box = entity.getBase();

      // Always pass the box by reference so that the OpenMP dialect
      // verifiers don't need to know anything about fir.box
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, box.getType());
      builder.create<fir::StoreOp>(currentLocation, box, alloca);

      symVal = alloca;
    } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
      // boxed arrays are passed as values not by reference. Unfortunately,
      // we can't pass a box by value to the OpenMP declare-reduction op, so
      // turn it into a reference

      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
      builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
      symVal = alloca;
    } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
      // Use the base result of the hlfir.declare that defines the variable.
      symVal = declOp.getBase();
    }

    // this isn't the same as the by-val and by-ref passing later in the
    // pipeline. Both styles assume that the variable is a reference at
    // this point
    assert(fir::isa_ref_type(symVal.getType()) &&
           "reduction input var is passed by reference");
    // Normalize to a fir.ref of the pointee type, keeping volatility.
    mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
    const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
    mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);

    reductionVars.push_back(
        builder.createConvert(currentLocation, refTy, symVal));
    reduceVarByRef.push_back(doReductionByRef(symVal));
  }

  // Create (or look up) one reduction declaration per collected variable.
  // NOTE(review): this walks every entry of reductionVars/reduceVarByRef, not
  // only those appended above; presumably callers pass empty vectors, verify.
  for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
    auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
    const auto &kindMap = firOpBuilder.getKindMap();
    std::string reductionName;
    ReductionIdentifier redId;
    // Logical variables use i1 as the type that goes into the name.
    mlir::Type redNameTy = redType;
    if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
      redNameTy = builder.getI1Type();

    if (const auto &redDefinedOp =
            std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
      const auto &intrinsicOp{
          std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
              redDefinedOp->u)};
      redId = getReductionType(intrinsicOp);
      // Only these intrinsic operators are lowered; the rest hit the TODO.
      switch (redId) {
      case ReductionIdentifier::ADD:
      case ReductionIdentifier::MULTIPLY:
      case ReductionIdentifier::AND:
      case ReductionIdentifier::EQV:
      case ReductionIdentifier::OR:
      case ReductionIdentifier::NEQV:
        break;
      default:
        TODO(currentLocation,
             "Reduction of some intrinsic operators is not supported");
        break;
      }

      reductionName =
          getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
    } else if (const auto *reductionIntrinsic =
                   std::get_if<omp::clause::ProcedureDesignator>(
                       &redOperator.u)) {
      // Already checked before the first loop; kept as a second guard.
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        TODO(currentLocation, "Unsupported intrinsic proc reduction");
      }
      redId = getReductionType(*reductionIntrinsic);
      reductionName =
          getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
                           redNameTy, isByRef);
    } else {
      TODO(currentLocation, "Unexpected reduction type");
    }

    decl = createDeclareReduction(converter, reductionName, redId, redType,
                                  currentLocation, isByRef);
    reductionDeclSymbols.push_back(
        mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
  }
}
723
724const semantics::SourceName
725ReductionProcessor::getRealName(const semantics::Symbol *symbol) {
726 return symbol->GetUltimate().name();
727}
728
729const semantics::SourceName
730ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
731 return getRealName(pd.v.sym());
732}
733
734int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
735 mlir::Location loc) {
736 switch (redId) {
737 case ReductionIdentifier::ADD:
738 case ReductionIdentifier::OR:
739 case ReductionIdentifier::NEQV:
740 return 0;
741 case ReductionIdentifier::MULTIPLY:
742 case ReductionIdentifier::AND:
743 case ReductionIdentifier::EQV:
744 return 1;
745 default:
746 TODO(loc, "Reduction of some intrinsic operators is not supported");
747 }
748}
749
750} // namespace omp
751} // namespace lower
752} // namespace Fortran
753

source code of flang/lib/Lower/OpenMP/ReductionProcessor.cpp