1 | //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "ReductionProcessor.h" |
14 | |
15 | #include "flang/Lower/AbstractConverter.h" |
16 | #include "flang/Lower/ConvertType.h" |
17 | #include "flang/Lower/SymbolMap.h" |
18 | #include "flang/Optimizer/Builder/Complex.h" |
19 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
20 | #include "flang/Optimizer/Builder/Todo.h" |
21 | #include "flang/Optimizer/Dialect/FIRType.h" |
22 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
23 | #include "flang/Optimizer/Support/FatalError.h" |
24 | #include "flang/Parser/tools.h" |
25 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
26 | #include "llvm/Support/CommandLine.h" |
27 | |
// Debug/testing knob: when set, every reduction argument is lowered using the
// by-reference code path, even trivially copyable scalars that would normally
// be passed by value.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction",
    llvm::cl::desc("Pass all reduction arguments by reference"),
    llvm::cl::Hidden);
32 | |
33 | namespace Fortran { |
34 | namespace lower { |
35 | namespace omp { |
36 | |
37 | ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
38 | const omp::clause::ProcedureDesignator &pd) { |
39 | auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( |
40 | getRealName(pd.v.id()).ToString()) |
41 | .Case("max" , ReductionIdentifier::MAX) |
42 | .Case("min" , ReductionIdentifier::MIN) |
43 | .Case("iand" , ReductionIdentifier::IAND) |
44 | .Case("ior" , ReductionIdentifier::IOR) |
45 | .Case("ieor" , ReductionIdentifier::IEOR) |
46 | .Default(std::nullopt); |
47 | assert(redType && "Invalid Reduction" ); |
48 | return *redType; |
49 | } |
50 | |
51 | ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
52 | omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { |
53 | switch (intrinsicOp) { |
54 | case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
55 | return ReductionIdentifier::ADD; |
56 | case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: |
57 | return ReductionIdentifier::SUBTRACT; |
58 | case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
59 | return ReductionIdentifier::MULTIPLY; |
60 | case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
61 | return ReductionIdentifier::AND; |
62 | case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
63 | return ReductionIdentifier::EQV; |
64 | case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
65 | return ReductionIdentifier::OR; |
66 | case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
67 | return ReductionIdentifier::NEQV; |
68 | default: |
69 | llvm_unreachable("unexpected intrinsic operator in reduction" ); |
70 | } |
71 | } |
72 | |
73 | bool ReductionProcessor::supportedIntrinsicProcReduction( |
74 | const omp::clause::ProcedureDesignator &pd) { |
75 | Fortran::semantics::Symbol *sym = pd.v.id(); |
76 | if (!sym->GetUltimate().attrs().test(Fortran::semantics::Attr::INTRINSIC)) |
77 | return false; |
78 | auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) |
79 | .Case("max" , true) |
80 | .Case("min" , true) |
81 | .Case("iand" , true) |
82 | .Case("ior" , true) |
83 | .Case("ieor" , true) |
84 | .Default(false); |
85 | return redType; |
86 | } |
87 | |
88 | std::string |
89 | ReductionProcessor::getReductionName(llvm::StringRef name, |
90 | const fir::KindMapping &kindMap, |
91 | mlir::Type ty, bool isByRef) { |
92 | ty = fir::unwrapRefType(ty); |
93 | |
94 | // extra string to distinguish reduction functions for variables passed by |
95 | // reference |
96 | llvm::StringRef byrefAddition{"" }; |
97 | if (isByRef) |
98 | byrefAddition = "_byref" ; |
99 | |
100 | return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); |
101 | } |
102 | |
103 | std::string ReductionProcessor::getReductionName( |
104 | omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
105 | const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { |
106 | std::string reductionName; |
107 | |
108 | switch (intrinsicOp) { |
109 | case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
110 | reductionName = "add_reduction" ; |
111 | break; |
112 | case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
113 | reductionName = "multiply_reduction" ; |
114 | break; |
115 | case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
116 | return "and_reduction" ; |
117 | case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
118 | return "eqv_reduction" ; |
119 | case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
120 | return "or_reduction" ; |
121 | case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
122 | return "neqv_reduction" ; |
123 | default: |
124 | reductionName = "other_reduction" ; |
125 | break; |
126 | } |
127 | |
128 | return getReductionName(reductionName, kindMap, ty, isByRef); |
129 | } |
130 | |
/// Materialize the start (identity) value for a reduction of the given
/// identifier over the given scalar type. Supports integer, real, complex
/// and logical types; anything else hits a TODO.
mlir::Value
ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type,
                                          ReductionIdentifier redId,
                                          fir::FirOpBuilder &builder) {
  type = fir::unwrapRefType(type);
  if (!fir::isa_integer(type) && !fir::isa_real(type) &&
      !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type))
    TODO(loc, "Reduction of some types is not supported");
  switch (redId) {
  case ReductionIdentifier::MAX: {
    // MAX starts from the most negative value of the type so any element
    // compares greater.
    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, minInt);
  }
  case ReductionIdentifier::MIN: {
    // MIN starts from the largest value of the type so any element compares
    // smaller.
    if (auto ty = type.dyn_cast<mlir::FloatType>()) {
      const llvm::fltSemantics &sem = ty.getFloatSemantics();
      return builder.createRealConstant(
          loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false));
    }
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, maxInt);
  }
  case ReductionIdentifier::IOR: {
    // Identity for bitwise OR: all bits zero.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, zeroInt);
  }
  case ReductionIdentifier::IEOR: {
    // Identity for bitwise XOR: all bits zero.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, zeroInt);
  }
  case ReductionIdentifier::IAND: {
    // Identity for bitwise AND: all bits one.
    unsigned bits = type.getIntOrFloatBitWidth();
    int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue();
    return builder.createIntegerConstant(loc, type, allOnInt);
  }
  case ReductionIdentifier::ADD:
  case ReductionIdentifier::MULTIPLY:
  case ReductionIdentifier::AND:
  case ReductionIdentifier::OR:
  case ReductionIdentifier::EQV:
  case ReductionIdentifier::NEQV:
    // These share a scalar identity supplied by getOperationIdentity; it only
    // needs to be materialized differently per type category.
    if (auto cplxTy = mlir::dyn_cast<fir::ComplexType>(type)) {
      // Complex identity: (identity, 0).
      mlir::Type realTy =
          Fortran::lower::convertReal(builder.getContext(), cplxTy.getFKind());
      mlir::Value initRe = builder.createRealConstant(
          loc, realTy, getOperationIdentity(redId, loc));
      mlir::Value initIm = builder.createRealConstant(loc, realTy, 0);

      return fir::factory::Complex{builder, loc}.createComplex(type, initRe,
                                                               initIm);
    }
    if (type.isa<mlir::FloatType>())
      return builder.create<mlir::arith::ConstantOp>(
          loc, type,
          builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc)));

    if (type.isa<fir::LogicalType>()) {
      // Build the identity as an i1 constant, then convert to the Fortran
      // logical type.
      mlir::Value intConst = builder.create<mlir::arith::ConstantOp>(
          loc, builder.getI1Type(),
          builder.getIntegerAttr(builder.getI1Type(),
                                 getOperationIdentity(redId, loc)));
      return builder.createConvert(loc, type, intConst);
    }

    return builder.create<mlir::arith::ConstantOp>(
        loc, type,
        builder.getIntegerAttr(type, getOperationIdentity(redId, loc)));
  case ReductionIdentifier::ID:
  case ReductionIdentifier::USER_DEF_OP:
  case ReductionIdentifier::SUBTRACT:
    TODO(loc, "Reduction of some identifier types is not supported");
  }
  llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue");
}
214 | |
/// Combine two scalar partial-reduction values (op1, op2) into one, using the
/// operation selected by redId. Operands are expected to be loaded values of
/// (the unwrapped) `type`; returns the combined value.
mlir::Value ReductionProcessor::createScalarCombiner(
    fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId,
    mlir::Type type, mlir::Value op1, mlir::Value op2) {
  mlir::Value reductionOp;
  type = fir::unwrapRefType(type);
  switch (redId) {
  case ReductionIdentifier::MAX:
    // getReductionOperation picks the float or integer op based on `type`.
    reductionOp =
        getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MIN:
    reductionOp =
        getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>(
            builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::IOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IEOR:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::IAND:
    assert((type.isIntOrIndex()) && "only integer is expected");
    reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2);
    break;
  case ReductionIdentifier::ADD:
    // Third template argument handles the complex-typed case.
    reductionOp =
        getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp,
                              fir::AddcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::MULTIPLY:
    reductionOp =
        getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp,
                              fir::MulcOp>(builder, type, loc, op1, op2);
    break;
  case ReductionIdentifier::AND: {
    // Logical ops are computed on i1 and the result converted back to the
    // Fortran logical type.
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, andiOp);
    break;
  }
  case ReductionIdentifier::OR: {
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, oriOp);
    break;
  }
  case ReductionIdentifier::EQV: {
    // .eqv. is equality comparison on i1.
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  case ReductionIdentifier::NEQV: {
    // .neqv. is inequality comparison on i1.
    mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1);
    mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2);

    mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>(
        loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1);

    reductionOp = builder.createConvert(loc, type, cmpiOp);
    break;
  }
  default:
    TODO(loc, "Reduction of some intrinsic operators is not supported");
  }

  return reductionOp;
}
297 | |
298 | /// Generate a fir::ShapeShift op describing the provided boxed array. |
299 | static fir::ShapeShiftOp getShapeShift(fir::FirOpBuilder &builder, |
300 | mlir::Location loc, mlir::Value box) { |
301 | fir::SequenceType sequenceType = mlir::cast<fir::SequenceType>( |
302 | hlfir::getFortranElementOrSequenceType(box.getType())); |
303 | const unsigned rank = sequenceType.getDimension(); |
304 | llvm::SmallVector<mlir::Value> lbAndExtents; |
305 | lbAndExtents.reserve(rank * 2); |
306 | |
307 | mlir::Type idxTy = builder.getIndexType(); |
308 | for (unsigned i = 0; i < rank; ++i) { |
309 | // TODO: ideally we want to hoist box reads out of the critical section. |
310 | // We could do this by having box dimensions in block arguments like |
311 | // OpenACC does |
312 | mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); |
313 | auto dimInfo = |
314 | builder.create<fir::BoxDimsOp>(loc, idxTy, idxTy, idxTy, box, dim); |
315 | lbAndExtents.push_back(dimInfo.getLowerBound()); |
316 | lbAndExtents.push_back(dimInfo.getExtent()); |
317 | } |
318 | |
319 | auto shapeShiftTy = fir::ShapeShiftType::get(builder.getContext(), rank); |
320 | auto shapeShift = |
321 | builder.create<fir::ShapeShiftOp>(loc, shapeShiftTy, lbAndExtents); |
322 | return shapeShift; |
323 | } |
324 | |
/// Create reduction combiner region for reduction variables which are boxed
/// arrays (or boxed allocatable scalars). lhs/rhs are fir.ref<fir.box<...>>
/// region arguments; the combined result is stored through lhs and the lhs
/// address is yielded.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  // Supported shapes: boxed array of known rank, or boxed heap (allocatable)
  // value. Anything else is not yet handled.
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction");

  // load fir.ref<fir.box<...>>
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  if (heapTy && !seqTy) {
    // Boxed allocatable scalar: combine the pointed-to values directly.
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, heapTy.getEleTy(), lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, lhs);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportuinties for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.innerLoop.getBody());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
  // Address the current element of each array using the loop indices.
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  builder.setInsertionPointAfter(nest.outerLoop);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
391 | |
392 | // generate combiner region for reduction operations |
393 | static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, |
394 | ReductionProcessor::ReductionIdentifier redId, |
395 | mlir::Type ty, mlir::Value lhs, mlir::Value rhs, |
396 | bool isByRef) { |
397 | ty = fir::unwrapRefType(ty); |
398 | |
399 | if (fir::isa_trivial(ty)) { |
400 | mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); |
401 | mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); |
402 | |
403 | mlir::Value result = ReductionProcessor::createScalarCombiner( |
404 | builder, loc, redId, ty, lhsLoaded, rhsLoaded); |
405 | if (isByRef) { |
406 | builder.create<fir::StoreOp>(loc, result, lhs); |
407 | builder.create<mlir::omp::YieldOp>(loc, lhs); |
408 | } else { |
409 | builder.create<mlir::omp::YieldOp>(loc, result); |
410 | } |
411 | return; |
412 | } |
413 | // all arrays should have been boxed |
414 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
415 | genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); |
416 | return; |
417 | } |
418 | |
419 | TODO(loc, "OpenMP genCombiner for unsupported reduction variable type" ); |
420 | } |
421 | |
/// Populate the cleanup region of the reduction declaration: free the heap
/// allocation made by the init region, if the private copy turns out to be
/// allocated. Only box types are expected here (they are the only ones the
/// init region heap-allocates); anything else is a fatal error.
static void
createReductionCleanupRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                             mlir::omp::DeclareReductionOp &reductionDecl) {
  mlir::Type redTy = reductionDecl.getType();

  mlir::Region &cleanupRegion = reductionDecl.getCleanupRegion();
  assert(cleanupRegion.empty());
  mlir::Block *block =
      builder.createBlock(&cleanupRegion, cleanupRegion.end(), {redTy}, {loc});
  builder.setInsertionPointToEnd(block);

  auto typeError = [loc]() {
    fir::emitFatalError(loc,
                        "Attempt to create an omp reduction cleanup region "
                        "for a type that wasn't allocated",
                        /*genCrashDiag=*/true);
  };

  mlir::Type valTy = fir::unwrapRefType(redTy);
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(valTy)) {
    if (!mlir::isa<fir::HeapType>(boxTy.getEleTy())) {
      // Non-heap boxes are only cleaned up when they box a sequence (array)
      // that the init region copied to the heap.
      mlir::Type innerTy = fir::extractSequenceType(boxTy);
      if (!mlir::isa<fir::SequenceType>(innerTy))
        typeError();
    }

    mlir::Value arg = block->getArgument(0);
    arg = builder.loadIfRef(loc, arg);
    assert(mlir::isa<fir::BaseBoxType>(arg.getType()));

    // Deallocate box
    // The FIR type system doesn't nesecarrily know that this is a mutable box
    // if we allocated the thread local array on the heap to avoid looped stack
    // allocations.
    mlir::Value addr =
        hlfir::genVariableRawAddress(loc, builder, hlfir::Entity{arg});
    // Guard the free with a null check: the private copy of an unallocated
    // allocatable is left unallocated by the init region.
    mlir::Value isAllocated = builder.genIsNotNullAddr(loc, addr);
    fir::IfOp ifOp =
        builder.create<fir::IfOp>(loc, isAllocated, /*withElseRegion=*/false);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());

    // Cast to a heap pointer so FreeMemOp accepts it.
    mlir::Value cast = builder.createConvert(
        loc, fir::HeapType::get(fir::dyn_cast_ptrEleTy(addr.getType())), addr);
    builder.create<fir::FreeMemOp>(loc, cast);

    builder.setInsertionPointAfter(ifOp);
    builder.create<mlir::omp::YieldOp>(loc);
    return;
  }

  typeError();
}
474 | |
475 | // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes |
476 | static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { |
477 | if (auto seqTy = ty.dyn_cast<fir::SequenceType>()) |
478 | return seqTy.getEleTy(); |
479 | if (auto boxTy = ty.dyn_cast<fir::BaseBoxType>()) { |
480 | auto eleTy = fir::unwrapRefType(boxTy.getEleTy()); |
481 | if (auto seqTy = eleTy.dyn_cast<fir::SequenceType>()) |
482 | return seqTy.getEleTy(); |
483 | return eleTy; |
484 | } |
485 | return ty; |
486 | } |
487 | |
/// Populate the initializer region of the reduction declaration: produce the
/// thread-private starting value (the reduction identity) for the incoming
/// variable. Trivial by-val reductions simply return the identity value;
/// by-ref and boxed reductions allocate a private copy, initialize it, and
/// return its address. Also creates the cleanup region when the private copy
/// is heap allocated.
static mlir::Value
createReductionInitRegion(fir::FirOpBuilder &builder, mlir::Location loc,
                          mlir::omp::DeclareReductionOp &reductionDecl,
                          const ReductionProcessor::ReductionIdentifier redId,
                          mlir::Type type, bool isByRef) {
  mlir::Type ty = fir::unwrapRefType(type);
  // Identity value for the element type (scalar even when reducing arrays).
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      builder.createStoreWithConvert(loc, initValue, alloca);
      return alloca;
    }
    // by val
    return initValue;
  }

  // check if an allocatable box is unallocated. If so, initialize the boxAlloca
  // to be unallocated e.g.
  // %box_alloca = fir.alloca !fir.box<!fir.heap<...>>
  // %addr = fir.box_addr %box
  // if (%addr == 0) {
  //   %nullbox = fir.embox %addr
  //   fir.store %nullbox to %box_alloca
  // } else {
  //   // ...
  //   fir.store %something to %box_alloca
  // }
  // omp.yield %box_alloca
  mlir::Value blockArg =
      builder.loadIfRef(loc, builder.getBlock()->getArgument(0));
  auto handleNullAllocatable = [&](mlir::Value boxAlloca) -> fir::IfOp {
    mlir::Value addr = builder.create<fir::BoxAddrOp>(loc, blockArg);
    mlir::Value isNotAllocated = builder.genIsNullAddr(loc, addr);
    fir::IfOp ifOp = builder.create<fir::IfOp>(loc, isNotAllocated,
                                               /*withElseRegion=*/true);
    builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
    // just embox the null address and return
    mlir::Value nullBox = builder.create<fir::EmboxOp>(loc, ty, addr);
    builder.create<fir::StoreOp>(loc, nullBox, boxAlloca);
    return ifOp;
  };

  // all arrays are boxed
  if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(ty)) {
    assert(isByRef && "passing boxes by value is unsupported");
    bool isAllocatable = mlir::isa<fir::HeapType>(boxTy.getEleTy());
    mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
    mlir::Type innerTy = fir::unwrapRefType(boxTy.getEleTy());
    if (fir::isa_trivial(innerTy)) {
      // boxed non-sequence value e.g. !fir.box<!fir.heap<i32>>
      if (!isAllocatable)
        TODO(loc, "Reduction of non-allocatable trivial typed box");

      fir::IfOp ifUnallocated = handleNullAllocatable(boxAlloca);

      // Allocated case: heap-allocate the private scalar, initialize it, and
      // box it.
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
      mlir::Value valAlloc = builder.create<fir::AllocMemOp>(loc, innerTy);
      builder.createStoreWithConvert(loc, initValue, valAlloc);
      mlir::Value box = builder.create<fir::EmboxOp>(loc, ty, valAlloc);
      builder.create<fir::StoreOp>(loc, box, boxAlloca);

      // The private copy lives on the heap, so a cleanup region is needed to
      // free it.
      auto insPt = builder.saveInsertionPoint();
      createReductionCleanupRegion(builder, loc, reductionDecl);
      builder.restoreInsertionPoint(insPt);
      builder.setInsertionPointAfter(ifUnallocated);
      return boxAlloca;
    }
    innerTy = fir::extractSequenceType(boxTy);
    if (!mlir::isa<fir::SequenceType>(innerTy))
      TODO(loc, "Unsupported boxed type for reduction");

    fir::IfOp ifUnallocated{nullptr};
    if (isAllocatable) {
      ifUnallocated = handleNullAllocatable(boxAlloca);
      builder.setInsertionPointToStart(&ifUnallocated.getElseRegion().front());
    }

    // Create the private copy from the initial fir.box:
    mlir::Value loadedBox = builder.loadIfRef(loc, blockArg);
    hlfir::Entity source = hlfir::Entity{loadedBox};

    // Allocating on the heap in case the whole reduction is nested inside of a
    // loop
    // TODO: compare performance here to using allocas - this could be made to
    // work by inserting stacksave/stackrestore around the reduction in
    // openmpirbuilder
    auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
    // if needsDealloc isn't statically false, add cleanup region. Always
    // do this for allocatable boxes because they might have been re-allocated
    // in the body of the loop/parallel region

    std::optional<int64_t> cstNeedsDealloc =
        fir::getIntIfConstant(needsDealloc);
    assert(cstNeedsDealloc.has_value() &&
           "createTempFromMold decides this statically");
    if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
      // Guard restores the insertion point after the cleanup region is built.
      mlir::OpBuilder::InsertionGuard guard(builder);
      createReductionCleanupRegion(builder, loc, reductionDecl);
    } else {
      assert(!isAllocatable && "Allocatable arrays must be heap allocated");
    }

    // Put the temporary inside of a box:
    // hlfir::genVariableBox doesn't handle non-default lower bounds
    mlir::Value box;
    fir::ShapeShiftOp shapeShift = getShapeShift(builder, loc, loadedBox);
    mlir::Type boxType = loadedBox.getType();
    if (mlir::isa<fir::BaseBoxType>(temp.getType()))
      // the box created by the declare form createTempFromMold is missing lower
      // bounds info
      box = builder.create<fir::ReboxOp>(loc, boxType, temp, shapeShift,
                                         /*shift=*/mlir::Value{});
    else
      box = builder.create<fir::EmboxOp>(
          loc, boxType, temp, shapeShift,
          /*slice=*/mlir::Value{},
          /*typeParams=*/llvm::ArrayRef<mlir::Value>{});

    // Fill the private array with the identity value and publish the box.
    builder.create<hlfir::AssignOp>(loc, initValue, box);
    builder.create<fir::StoreOp>(loc, box, boxAlloca);
    if (ifUnallocated)
      builder.setInsertionPointAfter(ifUnallocated);
    return boxAlloca;
  }

  TODO(loc, "createReductionInitRegion for unsupported type");
}
618 | |
/// Create the module-level omp.declare_reduction operation for the given
/// reduction identifier and type, building its initializer and combiner
/// regions (and, indirectly, its cleanup region when the private copy is heap
/// allocated). Returns an existing declaration if one with the same mangled
/// name was already created.
mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction(
    fir::FirOpBuilder &builder, llvm::StringRef reductionOpName,
    const ReductionIdentifier redId, mlir::Type type, mlir::Location loc,
    bool isByRef) {
  // Restore the builder's insertion point on exit; we insert at module scope
  // below.
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::ModuleOp module = builder.getModule();

  assert(!reductionOpName.empty());

  // Reuse an existing declaration with the same mangled name.
  auto decl =
      module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName);
  if (decl)
    return decl;

  mlir::OpBuilder modBuilder(module.getBodyRegion());
  mlir::Type valTy = fir::unwrapRefType(type);
  // By-value reductions declare the unwrapped value type; by-ref keep the
  // reference type.
  if (!isByRef)
    type = valTy;

  decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName,
                                                          type);
  // Initializer region: one block argument holding the original variable.
  builder.createBlock(&decl.getInitializerRegion(),
                      decl.getInitializerRegion().end(), {type}, {loc});
  builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());

  mlir::Value init =
      createReductionInitRegion(builder, loc, decl, redId, type, isByRef);
  builder.create<mlir::omp::YieldOp>(loc, init);

  // Combiner region: two block arguments (accumulator, incoming value).
  builder.createBlock(&decl.getReductionRegion(),
                      decl.getReductionRegion().end(), {type, type},
                      {loc, loc});

  builder.setInsertionPointToEnd(&decl.getReductionRegion().back());
  mlir::Value op1 = decl.getReductionRegion().front().getArgument(0);
  mlir::Value op2 = decl.getReductionRegion().front().getArgument(1);
  genCombiner(builder, loc, redId, type, op1, op2, isByRef);

  return decl;
}
659 | |
660 | // TODO: By-ref vs by-val reductions are currently toggled for the whole |
661 | // operation (possibly effecting multiple reduction variables). |
662 | // This could cause a problem with openmp target reductions because |
663 | // by-ref trivial types may not be supported. |
664 | bool ReductionProcessor::doReductionByRef( |
665 | const llvm::SmallVectorImpl<mlir::Value> &reductionVars) { |
666 | if (reductionVars.empty()) |
667 | return false; |
668 | if (forceByrefReduction) |
669 | return true; |
670 | |
671 | for (mlir::Value reductionVar : reductionVars) { |
672 | if (auto declare = |
673 | mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) |
674 | reductionVar = declare.getMemref(); |
675 | |
676 | if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) |
677 | return true; |
678 | } |
679 | return false; |
680 | } |
681 | |
682 | void ReductionProcessor::addDeclareReduction( |
683 | mlir::Location currentLocation, |
684 | Fortran::lower::AbstractConverter &converter, |
685 | const omp::clause::Reduction &reduction, |
686 | llvm::SmallVectorImpl<mlir::Value> &reductionVars, |
687 | llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols, |
688 | llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> |
689 | *reductionSymbols) { |
690 | fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); |
691 | |
692 | if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>( |
693 | reduction.t)) |
694 | TODO(currentLocation, "Reduction modifiers are not supported" ); |
695 | |
696 | mlir::omp::DeclareReductionOp decl; |
697 | const auto &redOperatorList{ |
698 | std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)}; |
699 | assert(redOperatorList.size() == 1 && "Expecting single operator" ); |
700 | const auto &redOperator = redOperatorList.front(); |
701 | const auto &objectList{std::get<omp::ObjectList>(reduction.t)}; |
702 | |
703 | if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) { |
704 | if (const auto *reductionIntrinsic = |
705 | std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) { |
706 | if (!ReductionProcessor::supportedIntrinsicProcReduction( |
707 | *reductionIntrinsic)) { |
708 | return; |
709 | } |
710 | } else { |
711 | return; |
712 | } |
713 | } |
714 | |
715 | // initial pass to collect all reduction vars so we can figure out if this |
716 | // should happen byref |
717 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
718 | for (const Object &object : objectList) { |
719 | const Fortran::semantics::Symbol *symbol = object.id(); |
720 | if (reductionSymbols) |
721 | reductionSymbols->push_back(symbol); |
722 | mlir::Value symVal = converter.getSymbolAddress(*symbol); |
723 | mlir::Type eleType; |
724 | auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType()); |
725 | if (refType) |
726 | eleType = refType.getEleTy(); |
727 | else |
728 | eleType = symVal.getType(); |
729 | |
730 | // all arrays must be boxed so that we have convenient access to all the |
731 | // information needed to iterate over the array |
732 | if (mlir::isa<fir::SequenceType>(eleType)) { |
733 | // For Host associated symbols, use `SymbolBox` instead |
734 | Fortran::lower::SymbolBox symBox = |
735 | converter.lookupOneLevelUpSymbol(*symbol); |
736 | hlfir::Entity entity{symBox.getAddr()}; |
737 | entity = genVariableBox(currentLocation, builder, entity); |
738 | mlir::Value box = entity.getBase(); |
739 | |
740 | // Always pass the box by reference so that the OpenMP dialect |
741 | // verifiers don't need to know anything about fir.box |
742 | auto alloca = |
743 | builder.create<fir::AllocaOp>(currentLocation, box.getType()); |
744 | builder.create<fir::StoreOp>(currentLocation, box, alloca); |
745 | |
746 | symVal = alloca; |
747 | } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) { |
748 | // boxed arrays are passed as values not by reference. Unfortunately, |
749 | // we can't pass a box by value to omp.redution_declare, so turn it |
750 | // into a reference |
751 | |
752 | auto alloca = |
753 | builder.create<fir::AllocaOp>(currentLocation, symVal.getType()); |
754 | builder.create<fir::StoreOp>(currentLocation, symVal, alloca); |
755 | symVal = alloca; |
756 | } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) { |
757 | symVal = declOp.getBase(); |
758 | } |
759 | |
760 | // this isn't the same as the by-val and by-ref passing later in the |
761 | // pipeline. Both styles assume that the variable is a reference at |
762 | // this point |
763 | assert(mlir::isa<fir::ReferenceType>(symVal.getType()) && |
764 | "reduction input var is a reference" ); |
765 | |
766 | reductionVars.push_back(symVal); |
767 | } |
768 | const bool isByRef = doReductionByRef(reductionVars); |
769 | |
770 | if (const auto &redDefinedOp = |
771 | std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) { |
772 | const auto &intrinsicOp{ |
773 | std::get<omp::clause::DefinedOperator::IntrinsicOperator>( |
774 | redDefinedOp->u)}; |
775 | ReductionIdentifier redId = getReductionType(intrinsicOp); |
776 | switch (redId) { |
777 | case ReductionIdentifier::ADD: |
778 | case ReductionIdentifier::MULTIPLY: |
779 | case ReductionIdentifier::AND: |
780 | case ReductionIdentifier::EQV: |
781 | case ReductionIdentifier::OR: |
782 | case ReductionIdentifier::NEQV: |
783 | break; |
784 | default: |
785 | TODO(currentLocation, |
786 | "Reduction of some intrinsic operators is not supported" ); |
787 | break; |
788 | } |
789 | |
790 | for (mlir::Value symVal : reductionVars) { |
791 | auto redType = mlir::cast<fir::ReferenceType>(symVal.getType()); |
792 | const auto &kindMap = firOpBuilder.getKindMap(); |
793 | if (redType.getEleTy().isa<fir::LogicalType>()) |
794 | decl = createDeclareReduction(firOpBuilder, |
795 | getReductionName(intrinsicOp, kindMap, |
796 | firOpBuilder.getI1Type(), |
797 | isByRef), |
798 | redId, redType, currentLocation, isByRef); |
799 | else |
800 | decl = createDeclareReduction( |
801 | firOpBuilder, |
802 | getReductionName(intrinsicOp, kindMap, redType, isByRef), redId, |
803 | redType, currentLocation, isByRef); |
804 | reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( |
805 | firOpBuilder.getContext(), decl.getSymName())); |
806 | } |
807 | } else if (const auto *reductionIntrinsic = |
808 | std::get_if<omp::clause::ProcedureDesignator>( |
809 | &redOperator.u)) { |
810 | if (ReductionProcessor::supportedIntrinsicProcReduction( |
811 | *reductionIntrinsic)) { |
812 | ReductionProcessor::ReductionIdentifier redId = |
813 | ReductionProcessor::getReductionType(*reductionIntrinsic); |
814 | for (const Object &object : objectList) { |
815 | const Fortran::semantics::Symbol *symbol = object.id(); |
816 | mlir::Value symVal = converter.getSymbolAddress(*symbol); |
817 | if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) |
818 | symVal = declOp.getBase(); |
819 | auto redType = symVal.getType().cast<fir::ReferenceType>(); |
820 | if (!redType.getEleTy().isIntOrIndexOrFloat()) |
821 | TODO(currentLocation, "User Defined Reduction on non-trivial type" ); |
822 | decl = createDeclareReduction( |
823 | firOpBuilder, |
824 | getReductionName(getRealName(*reductionIntrinsic).ToString(), |
825 | firOpBuilder.getKindMap(), redType, isByRef), |
826 | redId, redType, currentLocation, isByRef); |
827 | reductionDeclSymbols.push_back(mlir::SymbolRefAttr::get( |
828 | firOpBuilder.getContext(), decl.getSymName())); |
829 | } |
830 | } |
831 | } |
832 | } |
833 | |
834 | const Fortran::semantics::SourceName |
835 | ReductionProcessor::getRealName(const Fortran::semantics::Symbol *symbol) { |
836 | return symbol->GetUltimate().name(); |
837 | } |
838 | |
839 | const Fortran::semantics::SourceName |
840 | ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { |
841 | return getRealName(pd.v.id()); |
842 | } |
843 | |
844 | int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, |
845 | mlir::Location loc) { |
846 | switch (redId) { |
847 | case ReductionIdentifier::ADD: |
848 | case ReductionIdentifier::OR: |
849 | case ReductionIdentifier::NEQV: |
850 | return 0; |
851 | case ReductionIdentifier::MULTIPLY: |
852 | case ReductionIdentifier::AND: |
853 | case ReductionIdentifier::EQV: |
854 | return 1; |
855 | default: |
856 | TODO(loc, "Reduction of some intrinsic operators is not supported" ); |
857 | } |
858 | } |
859 | |
860 | } // namespace omp |
861 | } // namespace lower |
862 | } // namespace Fortran |
863 | |