1//===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "ReductionProcessor.h"
14
15#include "flang/Lower/AbstractConverter.h"
16#include "flang/Lower/ConvertType.h"
17#include "flang/Lower/SymbolMap.h"
18#include "flang/Optimizer/Builder/Complex.h"
19#include "flang/Optimizer/Builder/HLFIRTools.h"
20#include "flang/Optimizer/Builder/Todo.h"
21#include "flang/Optimizer/Dialect/FIRType.h"
22#include "flang/Optimizer/HLFIR/HLFIROps.h"
23#include "flang/Optimizer/Support/FatalError.h"
24#include "flang/Parser/tools.h"
25#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
26#include "llvm/Support/CommandLine.h"
27
28static llvm::cl::opt<bool> forceByrefReduction(
29 "force-byref-reduction",
30 llvm::cl::desc("Pass all reduction arguments by reference"),
31 llvm::cl::Hidden);
32
33namespace Fortran {
34namespace lower {
35namespace omp {
36
37ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
38 const omp::clause::ProcedureDesignator &pd) {
39 auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
40 getRealName(pd.v.id()).ToString())
41 .Case("max", ReductionIdentifier::MAX)
42 .Case("min", ReductionIdentifier::MIN)
43 .Case("iand", ReductionIdentifier::IAND)
44 .Case("ior", ReductionIdentifier::IOR)
45 .Case("ieor", ReductionIdentifier::IEOR)
46 .Default(std::nullopt);
47 assert(redType && "Invalid Reduction");
48 return *redType;
49}
50
51ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
52 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) {
53 switch (intrinsicOp) {
54 case omp::clause::DefinedOperator::IntrinsicOperator::Add:
55 return ReductionIdentifier::ADD;
56 case omp::clause::DefinedOperator::IntrinsicOperator::Subtract:
57 return ReductionIdentifier::SUBTRACT;
58 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
59 return ReductionIdentifier::MULTIPLY;
60 case omp::clause::DefinedOperator::IntrinsicOperator::AND:
61 return ReductionIdentifier::AND;
62 case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
63 return ReductionIdentifier::EQV;
64 case omp::clause::DefinedOperator::IntrinsicOperator::OR:
65 return ReductionIdentifier::OR;
66 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
67 return ReductionIdentifier::NEQV;
68 default:
69 llvm_unreachable("unexpected intrinsic operator in reduction");
70 }
71}
72
73bool ReductionProcessor::supportedIntrinsicProcReduction(
74 const omp::clause::ProcedureDesignator &pd) {
75 Fortran::semantics::Symbol *sym = pd.v.id();
76 if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC))
77 return false;
78 auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString())
79 .Case("max", true)
80 .Case("min", true)
81 .Case("iand", true)
82 .Case("ior", true)
83 .Case("ieor", true)
84 .Default(false);
85 return redType;
86}
87
88std::string
89ReductionProcessor::getReductionName(llvm::StringRef name,
90 const fir::KindMapping &kindMap,
91 mlir::Type ty, bool isByRef) {
92 ty = fir::unwrapRefType(ty);
93
94 // extra string to distinguish reduction functions for variables passed by
95 // reference
96 llvm::StringRef byrefAddition{""};
97 if (isByRef)
98 byrefAddition = "_byref";
99
100 return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str());
101}
102
103std::string ReductionProcessor::getReductionName(
104 omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp,
105 const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) {
106 std::string reductionName;
107
108 switch (intrinsicOp) {
109 case omp::clause::DefinedOperator::IntrinsicOperator::Add:
110 reductionName = "add_reduction";
111 break;
112 case omp::clause::DefinedOperator::IntrinsicOperator::Multiply:
113 reductionName = "multiply_reduction";
114 break;
115 case omp::clause::DefinedOperator::IntrinsicOperator::AND:
116 return "and_reduction";
117 case omp::clause::DefinedOperator::IntrinsicOperator::EQV:
118 return "eqv_reduction";
119 case omp::clause::DefinedOperator::IntrinsicOperator::OR:
120 return "or_reduction";
121 case omp::clause::DefinedOperator::IntrinsicOperator::NEQV:
122 return "neqv_reduction";
123 default:
124 reductionName = "other_reduction";
125 break;
126 }
127
128 return getReductionName(reductionName, kindMap, ty, isByRef);
129}
130
131mlir::Value
132ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
133 ReductionIdentifier redId,
134 fir::FirOpBuilder &builder) {
135 type = fir::unwrapRefType(type);
136 if (!fir::isa_integer(type) && !fir::isa_real(type) &&
137 !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
138 TODO(loc, "Reduction of some types is not supported");
139 switch (redId) {
140 case ReductionIdentifier::MAX: {
141 if (auto ty = type.dyn_cast<mlir::FloatType>()) {
142 const llvm::fltSemantics &sem = ty.getFloatSemantics();
143 return builder.createRealConstant(
144 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
145 }
146 unsigned bits = type.getIntOrFloatBitWidth();
147 int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
148 return builder.createIntegerConstant(loc, type, minInt);
149 }
150 case ReductionIdentifier::MIN: {
151 if (auto ty = type.dyn_cast<mlir::FloatType>()) {
152 const llvm::fltSemantics &sem = ty.getFloatSemantics();
153 return builder.createRealConstant(
154 loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
155 }
156 unsigned bits = type.getIntOrFloatBitWidth();
157 int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
158 return builder.createIntegerConstant(loc, type, maxInt);
159 }
160 case ReductionIdentifier::IOR: {
161 unsigned bits = type.getIntOrFloatBitWidth();
162 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
163 return builder.createIntegerConstant(loc, type, zeroInt);
164 }
165 case ReductionIdentifier::IEOR: {
166 unsigned bits = type.getIntOrFloatBitWidth();
167 int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
168 return builder.createIntegerConstant(loc, type, zeroInt);
169 }
170 case ReductionIdentifier::IAND: {
171 unsigned bits = type.getIntOrFloatBitWidth();
172 int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
173 return builder.createIntegerConstant(loc, type, allOnInt);
174 }
175 case ReductionIdentifier::ADD:
176 case ReductionIdentifier::MULTIPLY:
177 case ReductionIdentifier::AND:
178 case ReductionIdentifier::OR:
179 case ReductionIdentifier::EQV:
180 case ReductionIdentifier::NEQV:
181 if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
182 mlir::Type realTy =
183 Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
184 mlir::Value initRe = builder.createRealConstant(
185 loc, realTy, getOperationIdentity(redId, loc));
186 mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);
187
188 return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
189 initIm);
190 }
191 if (type.isa<mlir::FloatType>())
192 return builder.create<mlir::arith::ConstantOp>(
193 loc, type,
194 builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));
195
196 if (type.isa<fir::LogicalType>()) {
197 mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
198 loc, builder.getI1Type(),
199 builder.getIntegerAttr(builder.getI1Type(),
200 getOperationIdentity(redId, loc)));
201 return builder.createConvert(loc, type, intConst);
202 }
203
204 return builder.create<mlir::arith::ConstantOp>(
205 loc, type,
206 builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
207 case ReductionIdentifier::ID:
208 case ReductionIdentifier::USER_DEF_OP:
209 case ReductionIdentifier::SUBTRACT:
210 TODO(loc, "Reduction of some identifier types is not supported");
211 }
212 llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
213}
214
215mlir::Value ReductionProcessor::createScalarCombiner(
216 fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
217 mlir::Type type, mlir::Value op1, mlir::Value op2) {
218 mlir::Value reductionOp;
219 type = fir::unwrapRefType(type);
220 switch (redId) {
221 case ReductionIdentifier::MAX:
222 reductionOp =
223 getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
224 builder, type, loc, op1, op2);
225 break;
226 case ReductionIdentifier::MIN:
227 reductionOp =
228 getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
229 builder, type, loc, op1, op2);
230 break;
231 case ReductionIdentifier::IOR:
232 assert((type.isIntOrIndex()) && "only integer is expected");
233 reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
234 break;
235 case ReductionIdentifier::IEOR:
236 assert((type.isIntOrIndex()) && "only integer is expected");
237 reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
238 break;
239 case ReductionIdentifier::IAND:
240 assert((type.isIntOrIndex()) && "only integer is expected");
241 reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
242 break;
243 case ReductionIdentifier::ADD:
244 reductionOp =
245 getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
246 fir::AddcOp>(builder, type, loc, op1, op2);
247 break;
248 case ReductionIdentifier::MULTIPLY:
249 reductionOp =
250 getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
251 fir::MulcOp>(builder, type, loc, op1, op2);
252 break;
253 case ReductionIdentifier::AND: {
254 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
255 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
256
257 mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);
258
259 reductionOp = builder.createConvert(loc, type, andiOp);
260 break;
261 }
262 case ReductionIdentifier::OR: {
263 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
264 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
265
266 mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);
267
268 reductionOp = builder.createConvert(loc, type, oriOp);
269 break;
270 }
271 case ReductionIdentifier::EQV: {
272 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
273 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
274
275 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
276 loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);
277
278 reductionOp = builder.createConvert(loc, type, cmpiOp);
279 break;
280 }
281 case ReductionIdentifier::NEQV: {
282 mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
283 mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);
284
285 mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
286 loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);
287
288 reductionOp = builder.createConvert(loc, type, cmpiOp);
289 break;
290 }
291 default:
292 TODO(loc, "Reduction of some intrinsic operators is not supported");
293 }
294
295 return reductionOp;
296}
297
298/// Generate a fir::ShapeShift op describing the provided boxed array.
299static fir::ShapeShiftOp getShapeShift(fir::FirOpBuilder &builder,
300 mlir::Location loc, mlir::Value box) {
301 fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>(
302 hlfir::getFortranElementOrSequenceType(box.getType()));
303 const unsigned rank = sequenceType.getDimension();
304 llvm::SmallVector<mlir::Value> lbAndExtents;
305 lbAndExtents.reserve(rank * 2);
306
307 mlir::Type idxTy = builder.getIndexType();
308 for (unsigned i = 0; i < rank; ++i) {
309 // TODO: ideally we want to hoist box reads out of the critical section.
310 // We could do this by having box dimensions in block arguments like
311 // OpenACC does
312 mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
313 auto dimInfo =
314 builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim);
315 lbAndExtents.push_back(dimInfo.getLowerBound());
316 lbAndExtents.push_back(dimInfo.getExtent());
317 }
318
319 auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank);
320 auto shapeShift =
321 builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents);
322 return shapeShift;
323}
324
325/// Create reduction combiner region for reduction variables which are boxed
326/// arrays
327static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
328 ReductionProcessor::ReductionIdentifier redId,
329 fir::BaseBoxType boxTy, mlir::Value lhs,
330 mlir::Value rhs) {
331 fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
332 fir::unwrapRefType(boxTy.getEleTy()));
333 fir::HeapType heapTy =
334 mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
335 if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy)
336 TODO(loc, "Unsupported boxed type in OpenMP reduction");
337
338 // load fir.ref<fir.box<...>>
339 mlir::Value lhsAddr = lhs;
340 lhs = builder.create<fir::LoadOp>(loc, lhs);
341 rhs = builder.create<fir::LoadOp>(loc, rhs);
342
343 if (heapTy && !seqTy) {
344 // get box contents (heap pointers)
345 lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
346 rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
347 mlir::Value lhsValAddr = lhs;
348
349 // load heap pointers
350 lhs = builder.create<fir::LoadOp>(loc, lhs);
351 rhs = builder.create<fir::LoadOp>(loc, rhs);
352
353 mlir::Value result = ReductionProcessor::createScalarCombiner(
354 builder, loc, redId, heapTy.getEleTy(), lhs, rhs);
355 builder.create<fir::StoreOp>(loc, result, lhsValAddr);
356 builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
357 return;
358 }
359
360 fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs);
361
362 // Iterate over array elements, applying the equivalent scalar reduction:
363
364 // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
365 // and so no null check is needed here before indexing into the (possibly
366 // allocatable) arrays.
367
368 // A hlfir::elemental here gets inlined with a temporary so create the
369 // loop nest directly.
370 // This function already controls all of the code in this region so we
371 // know this won't miss any opportuinties for clever elemental inlining
372 hlfir::LoopNest nest = hlfir::genLoopNest(
373 loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
374 builder.setInsertionPointToStart(nest.innerLoop.getBody());
375 mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
376 auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
377 loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
378 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
379 auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
380 loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
381 nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
382 auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
383 auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
384 mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
385 builder, loc, redId, refTy, lhsEle, rhsEle);
386 builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
387
388 builder.setInsertionPointAfter(nest.outerLoop);
389 builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
390}
391
392// generate combiner region for reduction operations
393static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
394 ReductionProcessor::ReductionIdentifier redId,
395 mlir::Type ty, mlir::Value lhs, mlir::Value rhs,
396 bool isByRef) {
397 ty = fir::unwrapRefType(ty);
398
399 if (fir::isa_trivial(ty)) {
400 mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs);
401 mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs);
402
403 mlir::Value result = ReductionProcessor::createScalarCombiner(
404 builder, loc, redId, ty, lhsLoaded, rhsLoaded);
405 if (isByRef) {
406 builder.create<fir::StoreOp>(loc, result, lhs);
407 builder.create<mlir::omp::YieldOp>(loc, lhs);
408 } else {
409 builder.create<mlir::omp::YieldOp>(loc, result);
410 }
411 return;
412 }
413 // all arrays should have been boxed
414 if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) {
415 genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs);
416 return;
417 }
418
419 TODO(loc, "OpenMP genCombiner for unsupported reduction variable type");
420}
421
422static void
423createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
424 mlir::omp::DeclareReductionOp &reductionDecl) {
425 mlir::Type redTy = reductionDecl.getType();
426
427 mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
428 assert(cleanupRegion.empty());
429 mlir::Block *block =
430 builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
431 builder.setInsertionPointToEnd(block);
432
433 auto typeError = [loc]() {
434 fir::emitFatalError(loc,
435 "Attempt to create an omp reduction cleanup region "
436 "for a type that wasn't allocated",
437 /*genCrashDiag=*/true);
438 };
439
440 mlir::Type valTy = fir::unwrapRefType(redTy);
441 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
442 if (!mlir::isa<fir::HeapType>(boxTy.getEleTy())) {
443 mlir::Type innerTy = fir::extractSequenceType(boxTy);
444 if (!mlir::isa<fir::SequenceType>(innerTy))
445 typeError();
446 }
447
448 mlir::Value arg = block->getArgument(0);
449 arg = builder.loadIfRef(loc, arg);
450 assert(mlir::isa<fir::BaseBoxType>(arg.getType()));
451
452 // Deallocate box
453 // The FIR type system doesn't nesecarrily know that this is a mutable box
454 // if we allocated the thread local array on the heap to avoid looped stack
455 // allocations.
456 mlir::Value addr =
457 hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
458 mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
459 fir::IfOp ifOp =
460 builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
461 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
462
463 mlir::Value cast = builder.createConvert(
464 loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
465 builder.create<fir::FreeMemOp>(loc, cast);
466
467 builder.setInsertionPointAfter(ifOp);
468 builder.create<mlir::omp::YieldOp>(loc);
469 return;
470 }
471
472 typeError();
473}
474
475// like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes
476static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) {
477 if (auto seqTy = ty.dyn_cast<fir::SequenceType>())
478 return seqTy.getEleTy();
479 if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) {
480 auto eleTy = fir::unwrapRefType(boxTy.getEleTy());
481 if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>())
482 return seqTy.getEleTy();
483 return eleTy;
484 }
485 return ty;
486}
487
488static mlir::Value
489createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
490 mlir::omp::DeclareReductionOp &reductionDecl,
491 const ReductionProcessor::ReductionIdentifier redId,
492 mlir::Type type, bool isByRef) {
493 mlir::Type ty = fir::unwrapRefType(type);
494 mlir::Value initValue = ReductionProcessor::getReductionInitValue(
495 loc, unwrapSeqOrBoxedType(ty), redId, builder);
496
497 if (fir::isa_trivial(ty)) {
498 if (isByRef) {
499 mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
500 builder.createStoreWithConvert(loc, initValue, alloca);
501 return alloca;
502 }
503 // by val
504 return initValue;
505 }
506
507 // check if an allocatable box is unallocated. If so, initialize the boxAlloca
508 // to be unallocated e.g.
509 // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
510 // %addr = fir.box_addr %box
511 // if (%addr == 0) {
512 // %nullbox = fir.embox %addr
513 // fir.store %nullbox to %box_alloca
514 // } else {
515 // // ...
516 // fir.store %something to %box_alloca
517 // }
518 // omp.yield %box_alloca
519 mlir::Value blockArg =
520 builder.loadIfRef(loc, builder.getBlock()->getArgument(0));
521 auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
522 mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, blockArg);
523 mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr);
524 fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated,
525 /*withElseRegion=*/true);
526 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
527 // just embox the null address and return
528 mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr);
529 builder.create<fir::StoreOp>(loc, nullBox, boxAlloca);
530 return ifOp;
531 };
532
533 // all arrays are boxed
534 if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
535 assert(isByRef && "passing boxes by value is unsupported");
536 bool isAllocatable = mlir::isa<fir::HeapType>(boxTy.getEleTy());
537 mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
538 mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
539 if (fir::isa_trivial(innerTy)) {
540 // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
541 if (!isAllocatable)
542 TODO(loc, "Reduction of non-allocatable trivial typed box");
543
544 fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);
545
546 builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
547 mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy);
548 builder.createStoreWithConvert(loc, initValue, valAlloc);
549 mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc);
550 builder.create<fir::StoreOp>(loc, box, boxAlloca);
551
552 auto insPt = builder.saveInsertionPoint();
553 createReductionCleanupRegion(builder, loc, reductionDecl);
554 builder.restoreInsertionPoint(insPt);
555 builder.setInsertionPointAfter(ifUnallocated);
556 return boxAlloca;
557 }
558 innerTy = fir::extractSequenceType(boxTy);
559 if (!mlir::isa<fir::SequenceType>(innerTy))
560 TODO(loc, "Unsupported boxed type for reduction");
561
562 fir::IfOp ifUnallocated{nullptr};
563 if (isAllocatable) {
564 ifUnallocated = handleNullAllocatable(boxAlloca);
565 builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
566 }
567
568 // Create the private copy from the initial fir.box:
569 mlir::Value loadedBox = builder.loadIfRef(loc, blockArg);
570 hlfir::Entity source = hlfir::Entity{loadedBox};
571
572 // Allocating on the heap in case the whole reduction is nested inside of a
573 // loop
574 // TODO: compare performance here to using allocas - this could be made to
575 // work by inserting stacksave/stackrestore around the reduction in
576 // openmpirbuilder
577 auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
578 // if needsDealloc isn't statically false, add cleanup region. Always
579 // do this for allocatable boxes because they might have been re-allocated
580 // in the body of the loop/parallel region
581
582 std::optional<int64_t> cstNeedsDealloc =
583 fir::getIntIfConstant(needsDealloc);
584 assert(cstNeedsDealloc.has_value() &&
585 "createTempFromMold decides this statically");
586 if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
587 mlir::OpBuilder::InsertionGuard guard(builder);
588 createReductionCleanupRegion(builder, loc, reductionDecl);
589 } else {
590 assert(!isAllocatable && "Allocatable arrays must be heap allocated");
591 }
592
593 // Put the temporary inside of a box:
594 // hlfir::genVariableBox doesn't handle non-default lower bounds
595 mlir::Value box;
596 fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox);
597 mlir::Type boxType = loadedBox.getType();
598 if (mlir::isa<fir::BaseBoxType>(temp.getType()))
599 // the box created by the declare form createTempFromMold is missing lower
600 // bounds info
601 box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift,
602 /*shift=*/mlir::Value{});
603 else
604 box = builder.create<fir::EmboxOp>(
605 loc, boxType, temp, shapeShift,
606 /*slice=*/mlir::Value{},
607 /*typeParams=*/llvm::ArrayRef<mlir::Value>{});
608
609 builder.create<hlfir::AssignOp>(loc, initValue, box);
610 builder.create<fir::StoreOp>(loc, box, boxAlloca);
611 if (ifUnallocated)
612 builder.setInsertionPointAfter(ifUnallocated);
613 return boxAlloca;
614 }
615
616 TODO(loc, "createReductionInitRegion for unsupported type");
617}
618
619mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
620 fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
621 const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
622 bool isByRef) {
623 mlir::OpBuilder::InsertionGuard guard(builder);
624 mlir::ModuleOp module = builder.getModule();
625
626 assert(!reductionOpName.empty());
627
628 auto decl =
629 module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
630 if (decl)
631 return decl;
632
633 mlir::OpBuilder modBuilder(module.getBodyRegion());
634 mlir::Type valTy = fir::unwrapRefType(type);
635 if (!isByRef)
636 type = valTy;
637
638 decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
639 type);
640 builder.createBlock(&decl.getInitializerRegion(),
641 decl.getInitializerRegion().end(), {type}, {loc});
642 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
643
644 mlir::Value init =
645 createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
646 builder.create<mlir::omp::YieldOp>(loc, init);
647
648 builder.createBlock(&decl.getReductionRegion(),
649 decl.getReductionRegion().end(), {type, type},
650 {loc, loc});
651
652 builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
653 mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
654 mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
655 genCombiner(builder, loc, redId, type, op1, op2, isByRef);
656
657 return decl;
658}
659
660// TODO: By-ref vs by-val reductions are currently toggled for the whole
661// operation (possibly effecting multiple reduction variables).
662// This could cause a problem with openmp target reductions because
663// by-ref trivial types may not be supported.
664bool ReductionProcessor::doReductionByRef(
665 const llvm::SmallVectorImpl<mlir::Value> &reductionVars) {
666 if (reductionVars.empty())
667 return false;
668 if (forceByrefReduction)
669 return true;
670
671 for (mlir::Value reductionVar : reductionVars) {
672 if (auto declare =
673 mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp()))
674 reductionVar = declare.getMemref();
675
676 if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType())))
677 return true;
678 }
679 return false;
680}
681
682void ReductionProcessor::addDeclareReduction(
683 mlir::Location currentLocation,
684 Fortran::lower::AbstractConverter &converter,
685 const omp::clause::Reduction &reduction,
686 llvm::SmallVectorImpl<mlir::Value> &reductionVars,
687 llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
688 llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
689 *reductionSymbols) {
690 fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
691
692 if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
693 reduction.t))
694 TODO(currentLocation, "Reduction modifiers are not supported");
695
696 mlir::omp::DeclareReductionOp decl;
697 const auto &redOperatorList{
698 std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
699 assert(redOperatorList.size() == 1 && "Expecting single operator");
700 const auto &redOperator = redOperatorList.front();
701 const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
702
703 if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
704 if (const auto *reductionIntrinsic =
705 std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
706 if (!ReductionProcessor::supportedIntrinsicProcReduction(
707 *reductionIntrinsic)) {
708 return;
709 }
710 } else {
711 return;
712 }
713 }
714
715 // initial pass to collect all reduction vars so we can figure out if this
716 // should happen byref
717 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
718 for (const Object &object : objectList) {
719 const Fortran::semantics::Symbol *symbol = object.id();
720 if (reductionSymbols)
721 reductionSymbols->push_back(symbol);
722 mlir::Value symVal = converter.getSymbolAddress(*symbol);
723 mlir::Type eleType;
724 auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
725 if (refType)
726 eleType = refType.getEleTy();
727 else
728 eleType = symVal.getType();
729
730 // all arrays must be boxed so that we have convenient access to all the
731 // information needed to iterate over the array
732 if (mlir::isa<fir::SequenceType>(eleType)) {
733 // For Host associated symbols, use `SymbolBox` instead
734 Fortran::lower::SymbolBox symBox =
735 converter.lookupOneLevelUpSymbol(*symbol);
736 hlfir::Entity entity{symBox.getAddr()};
737 entity = genVariableBox(currentLocation, builder, entity);
738 mlir::Value box = entity.getBase();
739
740 // Always pass the box by reference so that the OpenMP dialect
741 // verifiers don't need to know anything about fir.box
742 auto alloca =
743 builder.create<fir::AllocaOp>(currentLocation, box.getType());
744 builder.create<fir::StoreOp>(currentLocation, box, alloca);
745
746 symVal = alloca;
747 } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
748 // boxed arrays are passed as values not by reference. Unfortunately,
749 // we can't pass a box by value to omp.redution_declare, so turn it
750 // into a reference
751
752 auto alloca =
753 builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
754 builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
755 symVal = alloca;
756 } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
757 symVal = declOp.getBase();
758 }
759
760 // this isn't the same as the by-val and by-ref passing later in the
761 // pipeline. Both styles assume that the variable is a reference at
762 // this point
763 assert(mlir::isa<fir::ReferenceType>(symVal.getType()) &&
764 "reduction input var is a reference");
765
766 reductionVars.push_back(symVal);
767 }
768 const bool isByRef = doReductionByRef(reductionVars);
769
770 if (const auto &redDefinedOp =
771 std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
772 const auto &intrinsicOp{
773 std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
774 redDefinedOp->u)};
775 ReductionIdentifier redId = getReductionType(intrinsicOp);
776 switch (redId) {
777 case ReductionIdentifier::ADD:
778 case ReductionIdentifier::MULTIPLY:
779 case ReductionIdentifier::AND:
780 case ReductionIdentifier::EQV:
781 case ReductionIdentifier::OR:
782 case ReductionIdentifier::NEQV:
783 break;
784 default:
785 TODO(currentLocation,
786 "Reduction of some intrinsic operators is not supported");
787 break;
788 }
789
790 for (mlir::Value symVal : reductionVars) {
791 auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
792 const auto &kindMap = firOpBuilder.getKindMap();
793 if (redType.getEleTy().isa<fir::LogicalType>())
794 decl = createDeclareReduction(firOpBuilder,
795 getReductionName(intrinsicOp, kindMap,
796 firOpBuilder.getI1Type(),
797 isByRef),
798 redId, redType, currentLocation, isByRef);
799 else
800 decl = createDeclareReduction(
801 firOpBuilder,
802 getReductionName(intrinsicOp, kindMap, redType, isByRef), redId,
803 redType, currentLocation, isByRef);
804 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
805 firOpBuilder.getContext(), decl.getSymName()));
806 }
807 } else if (const auto *reductionIntrinsic =
808 std::get_if<omp::clause::ProcedureDesignator>(
809 &redOperator.u)) {
810 if (ReductionProcessor::supportedIntrinsicProcReduction(
811 *reductionIntrinsic)) {
812 ReductionProcessor::ReductionIdentifier redId =
813 ReductionProcessor::getReductionType(*reductionIntrinsic);
814 for (const Object &object : objectList) {
815 const Fortran::semantics::Symbol *symbol = object.id();
816 mlir::Value symVal = converter.getSymbolAddress(*symbol);
817 if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
818 symVal = declOp.getBase();
819 auto redType = symVal.getType().cast<fir::ReferenceType>();
820 if (!redType.getEleTy().isIntOrIndexOrFloat())
821 TODO(currentLocation, "User Defined Reduction on non-trivial type");
822 decl = createDeclareReduction(
823 firOpBuilder,
824 getReductionName(getRealName(*reductionIntrinsic).ToString(),
825 firOpBuilder.getKindMap(), redType, isByRef),
826 redId, redType, currentLocation, isByRef);
827 reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get(
828 firOpBuilder.getContext(), decl.getSymName()));
829 }
830 }
831 }
832}
833
834const Fortran::semantics::SourceName
835ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) {
836 return symbol->GetUltimate().name();
837}
838
839const Fortran::semantics::SourceName
840ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) {
841 return getRealName(pd.v.id());
842}
843
844int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId,
845 mlir::Location loc) {
846 switch (redId) {
847 case ReductionIdentifier::ADD:
848 case ReductionIdentifier::OR:
849 case ReductionIdentifier::NEQV:
850 return 0;
851 case ReductionIdentifier::MULTIPLY:
852 case ReductionIdentifier::AND:
853 case ReductionIdentifier::EQV:
854 return 1;
855 default:
856 TODO(loc, "Reduction of some intrinsic operators is not supported");
857 }
858}
859
860} // namespace omp
861} // namespace lower
862} // namespace Fortran
863

// Source: flang/lib/Lower/OpenMP/ReductionProcessor.cpp