1 | //===-- ReductionProcessor.cpp ----------------------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "ReductionProcessor.h" |
14 | |
15 | #include "flang/Lower/AbstractConverter.h" |
16 | #include "flang/Lower/ConvertType.h" |
17 | #include "flang/Lower/Support/PrivateReductionUtils.h" |
18 | #include "flang/Lower/SymbolMap.h" |
19 | #include "flang/Optimizer/Builder/Complex.h" |
20 | #include "flang/Optimizer/Builder/HLFIRTools.h" |
21 | #include "flang/Optimizer/Builder/Todo.h" |
22 | #include "flang/Optimizer/Dialect/FIRType.h" |
23 | #include "flang/Optimizer/HLFIR/HLFIROps.h" |
24 | #include "flang/Optimizer/Support/FatalError.h" |
25 | #include "flang/Parser/tools.h" |
26 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
27 | #include "llvm/Support/CommandLine.h" |
28 | #include <type_traits> |
29 | |
// Hidden command-line switch. When it is set, doReductionByRef() in this file
// reports every reduction variable as by-reference without looking at its
// type.
static llvm::cl::opt<bool> forceByrefReduction(
    "force-byref-reduction" ,
    llvm::cl::desc("Pass all reduction arguments by reference" ),
    llvm::cl::Hidden);
34 | |
// File-level shorthand for the modifier enumeration nested in the `reduction`
// clause representation.
using ReductionModifier =
    Fortran::lower::omp::clause::Reduction::ReductionModifier;
37 | |
38 | namespace Fortran { |
39 | namespace lower { |
40 | namespace omp { |
41 | |
// Explicit template instantiations. processReductionArguments is defined
// further down in this file, so it is instantiated here once for each clause
// type listed below: Reduction, TaskReduction and InReduction.
template void
ReductionProcessor::processReductionArguments<omp::clause::Reduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::Reduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::TaskReduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::TaskReduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);

template void
ReductionProcessor::processReductionArguments<omp::clause::InReduction>(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const omp::clause::InReduction &reduction,
    llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod);
72 | |
73 | ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
74 | const omp::clause::ProcedureDesignator &pd) { |
75 | auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>( |
76 | getRealName(pd.v.sym()).ToString()) |
77 | .Case("max" , ReductionIdentifier::MAX) |
78 | .Case("min" , ReductionIdentifier::MIN) |
79 | .Case("iand" , ReductionIdentifier::IAND) |
80 | .Case("ior" , ReductionIdentifier::IOR) |
81 | .Case("ieor" , ReductionIdentifier::IEOR) |
82 | .Default(std::nullopt); |
83 | assert(redType && "Invalid Reduction" ); |
84 | return *redType; |
85 | } |
86 | |
87 | ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType( |
88 | omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp) { |
89 | switch (intrinsicOp) { |
90 | case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
91 | return ReductionIdentifier::ADD; |
92 | case omp::clause::DefinedOperator::IntrinsicOperator::Subtract: |
93 | return ReductionIdentifier::SUBTRACT; |
94 | case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
95 | return ReductionIdentifier::MULTIPLY; |
96 | case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
97 | return ReductionIdentifier::AND; |
98 | case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
99 | return ReductionIdentifier::EQV; |
100 | case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
101 | return ReductionIdentifier::OR; |
102 | case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
103 | return ReductionIdentifier::NEQV; |
104 | default: |
105 | llvm_unreachable("unexpected intrinsic operator in reduction" ); |
106 | } |
107 | } |
108 | |
109 | bool ReductionProcessor::supportedIntrinsicProcReduction( |
110 | const omp::clause::ProcedureDesignator &pd) { |
111 | semantics::Symbol *sym = pd.v.sym(); |
112 | if (!sym->GetUltimate().attrs().test(semantics::Attr::INTRINSIC)) |
113 | return false; |
114 | auto redType = llvm::StringSwitch<bool>(getRealName(sym).ToString()) |
115 | .Case("max" , true) |
116 | .Case("min" , true) |
117 | .Case("iand" , true) |
118 | .Case("ior" , true) |
119 | .Case("ieor" , true) |
120 | .Default(false); |
121 | return redType; |
122 | } |
123 | |
124 | std::string |
125 | ReductionProcessor::getReductionName(llvm::StringRef name, |
126 | const fir::KindMapping &kindMap, |
127 | mlir::Type ty, bool isByRef) { |
128 | ty = fir::unwrapRefType(ty); |
129 | |
130 | // extra string to distinguish reduction functions for variables passed by |
131 | // reference |
132 | llvm::StringRef byrefAddition{"" }; |
133 | if (isByRef) |
134 | byrefAddition = "_byref" ; |
135 | |
136 | return fir::getTypeAsString(ty, kindMap, (name + byrefAddition).str()); |
137 | } |
138 | |
139 | std::string ReductionProcessor::getReductionName( |
140 | omp::clause::DefinedOperator::IntrinsicOperator intrinsicOp, |
141 | const fir::KindMapping &kindMap, mlir::Type ty, bool isByRef) { |
142 | std::string reductionName; |
143 | |
144 | switch (intrinsicOp) { |
145 | case omp::clause::DefinedOperator::IntrinsicOperator::Add: |
146 | reductionName = "add_reduction" ; |
147 | break; |
148 | case omp::clause::DefinedOperator::IntrinsicOperator::Multiply: |
149 | reductionName = "multiply_reduction" ; |
150 | break; |
151 | case omp::clause::DefinedOperator::IntrinsicOperator::AND: |
152 | return "and_reduction" ; |
153 | case omp::clause::DefinedOperator::IntrinsicOperator::EQV: |
154 | return "eqv_reduction" ; |
155 | case omp::clause::DefinedOperator::IntrinsicOperator::OR: |
156 | return "or_reduction" ; |
157 | case omp::clause::DefinedOperator::IntrinsicOperator::NEQV: |
158 | return "neqv_reduction" ; |
159 | default: |
160 | reductionName = "other_reduction" ; |
161 | break; |
162 | } |
163 | |
164 | return getReductionName(reductionName, kindMap, ty, isByRef); |
165 | } |
166 | |
167 | mlir::Value |
168 | ReductionProcessor::getReductionInitValue(mlir::Location loc, mlir::Type type, |
169 | ReductionIdentifier redId, |
170 | fir::FirOpBuilder &builder) { |
171 | type = fir::unwrapRefType(type); |
172 | if (!fir::isa_integer(type) && !fir::isa_real(type) && |
173 | !fir::isa_complex(type) && !mlir::isa<fir::LogicalType>(type)) |
174 | TODO(loc, "Reduction of some types is not supported" ); |
175 | switch (redId) { |
176 | case ReductionIdentifier::MAX: { |
177 | if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { |
178 | const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
179 | return builder.createRealConstant( |
180 | loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/true)); |
181 | } |
182 | unsigned bits = type.getIntOrFloatBitWidth(); |
183 | int64_t minInt = llvm::APInt::getSignedMinValue(bits).getSExtValue(); |
184 | return builder.createIntegerConstant(loc, type, minInt); |
185 | } |
186 | case ReductionIdentifier::MIN: { |
187 | if (auto ty = mlir::dyn_cast<mlir::FloatType>(type)) { |
188 | const llvm::fltSemantics &sem = ty.getFloatSemantics(); |
189 | return builder.createRealConstant( |
190 | loc, type, llvm::APFloat::getLargest(sem, /*Negative=*/false)); |
191 | } |
192 | unsigned bits = type.getIntOrFloatBitWidth(); |
193 | int64_t maxInt = llvm::APInt::getSignedMaxValue(bits).getSExtValue(); |
194 | return builder.createIntegerConstant(loc, type, maxInt); |
195 | } |
196 | case ReductionIdentifier::IOR: { |
197 | unsigned bits = type.getIntOrFloatBitWidth(); |
198 | int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
199 | return builder.createIntegerConstant(loc, type, zeroInt); |
200 | } |
201 | case ReductionIdentifier::IEOR: { |
202 | unsigned bits = type.getIntOrFloatBitWidth(); |
203 | int64_t zeroInt = llvm::APInt::getZero(bits).getSExtValue(); |
204 | return builder.createIntegerConstant(loc, type, zeroInt); |
205 | } |
206 | case ReductionIdentifier::IAND: { |
207 | unsigned bits = type.getIntOrFloatBitWidth(); |
208 | int64_t allOnInt = llvm::APInt::getAllOnes(bits).getSExtValue(); |
209 | return builder.createIntegerConstant(loc, type, allOnInt); |
210 | } |
211 | case ReductionIdentifier::ADD: |
212 | case ReductionIdentifier::MULTIPLY: |
213 | case ReductionIdentifier::AND: |
214 | case ReductionIdentifier::OR: |
215 | case ReductionIdentifier::EQV: |
216 | case ReductionIdentifier::NEQV: |
217 | if (auto cplxTy = mlir::dyn_cast<mlir::ComplexType>(type)) { |
218 | mlir::Type realTy = cplxTy.getElementType(); |
219 | mlir::Value initRe = builder.createRealConstant( |
220 | loc, realTy, getOperationIdentity(redId, loc)); |
221 | mlir::Value initIm = builder.createRealConstant(loc, realTy, 0); |
222 | |
223 | return fir::factory::Complex{builder, loc}.createComplex(type, initRe, |
224 | initIm); |
225 | } |
226 | if (mlir::isa<mlir::FloatType>(type)) |
227 | return builder.create<mlir::arith::ConstantOp>( |
228 | loc, type, |
229 | builder.getFloatAttr(type, (double)getOperationIdentity(redId, loc))); |
230 | |
231 | if (mlir::isa<fir::LogicalType>(type)) { |
232 | mlir::Value intConst = builder.create<mlir::arith::ConstantOp>( |
233 | loc, builder.getI1Type(), |
234 | builder.getIntegerAttr(builder.getI1Type(), |
235 | getOperationIdentity(redId, loc))); |
236 | return builder.createConvert(loc, type, intConst); |
237 | } |
238 | |
239 | return builder.create<mlir::arith::ConstantOp>( |
240 | loc, type, |
241 | builder.getIntegerAttr(type, getOperationIdentity(redId, loc))); |
242 | case ReductionIdentifier::ID: |
243 | case ReductionIdentifier::USER_DEF_OP: |
244 | case ReductionIdentifier::SUBTRACT: |
245 | TODO(loc, "Reduction of some identifier types is not supported" ); |
246 | } |
247 | llvm_unreachable("Unhandled Reduction identifier : getReductionInitValue" ); |
248 | } |
249 | |
250 | mlir::Value ReductionProcessor::createScalarCombiner( |
251 | fir::FirOpBuilder &builder, mlir::Location loc, ReductionIdentifier redId, |
252 | mlir::Type type, mlir::Value op1, mlir::Value op2) { |
253 | mlir::Value reductionOp; |
254 | type = fir::unwrapRefType(type); |
255 | switch (redId) { |
256 | case ReductionIdentifier::MAX: |
257 | reductionOp = |
258 | getReductionOperation<mlir::arith::MaxNumFOp, mlir::arith::MaxSIOp>( |
259 | builder, type, loc, op1, op2); |
260 | break; |
261 | case ReductionIdentifier::MIN: |
262 | reductionOp = |
263 | getReductionOperation<mlir::arith::MinNumFOp, mlir::arith::MinSIOp>( |
264 | builder, type, loc, op1, op2); |
265 | break; |
266 | case ReductionIdentifier::IOR: |
267 | assert((type.isIntOrIndex()) && "only integer is expected" ); |
268 | reductionOp = builder.create<mlir::arith::OrIOp>(loc, op1, op2); |
269 | break; |
270 | case ReductionIdentifier::IEOR: |
271 | assert((type.isIntOrIndex()) && "only integer is expected" ); |
272 | reductionOp = builder.create<mlir::arith::XOrIOp>(loc, op1, op2); |
273 | break; |
274 | case ReductionIdentifier::IAND: |
275 | assert((type.isIntOrIndex()) && "only integer is expected" ); |
276 | reductionOp = builder.create<mlir::arith::AndIOp>(loc, op1, op2); |
277 | break; |
278 | case ReductionIdentifier::ADD: |
279 | reductionOp = |
280 | getReductionOperation<mlir::arith::AddFOp, mlir::arith::AddIOp, |
281 | fir::AddcOp>(builder, type, loc, op1, op2); |
282 | break; |
283 | case ReductionIdentifier::MULTIPLY: |
284 | reductionOp = |
285 | getReductionOperation<mlir::arith::MulFOp, mlir::arith::MulIOp, |
286 | fir::MulcOp>(builder, type, loc, op1, op2); |
287 | break; |
288 | case ReductionIdentifier::AND: { |
289 | mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
290 | mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
291 | |
292 | mlir::Value andiOp = builder.create<mlir::arith::AndIOp>(loc, op1I1, op2I1); |
293 | |
294 | reductionOp = builder.createConvert(loc, type, andiOp); |
295 | break; |
296 | } |
297 | case ReductionIdentifier::OR: { |
298 | mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
299 | mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
300 | |
301 | mlir::Value oriOp = builder.create<mlir::arith::OrIOp>(loc, op1I1, op2I1); |
302 | |
303 | reductionOp = builder.createConvert(loc, type, oriOp); |
304 | break; |
305 | } |
306 | case ReductionIdentifier::EQV: { |
307 | mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
308 | mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
309 | |
310 | mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
311 | loc, mlir::arith::CmpIPredicate::eq, op1I1, op2I1); |
312 | |
313 | reductionOp = builder.createConvert(loc, type, cmpiOp); |
314 | break; |
315 | } |
316 | case ReductionIdentifier::NEQV: { |
317 | mlir::Value op1I1 = builder.createConvert(loc, builder.getI1Type(), op1); |
318 | mlir::Value op2I1 = builder.createConvert(loc, builder.getI1Type(), op2); |
319 | |
320 | mlir::Value cmpiOp = builder.create<mlir::arith::CmpIOp>( |
321 | loc, mlir::arith::CmpIPredicate::ne, op1I1, op2I1); |
322 | |
323 | reductionOp = builder.createConvert(loc, type, cmpiOp); |
324 | break; |
325 | } |
326 | default: |
327 | TODO(loc, "Reduction of some intrinsic operators is not supported" ); |
328 | } |
329 | |
330 | return reductionOp; |
331 | } |
332 | |
/// Create reduction combiner region for reduction variables which are boxed
/// arrays
///
/// \p lhs and \p rhs are loaded before use (the comment below gives their
/// type as fir.ref<fir.box<...>>). Two shapes of box are handled:
///  - a box whose element type is a heap or pointer type with no sequence
///    type inside: the two boxed values are combined with
///    ReductionProcessor::createScalarCombiner and the result is stored back
///    through the address held in the lhs box;
///  - a box containing a sequence type: a loop nest over the array extents
///    applies the scalar combiner to each pair of elements and stores the
///    result into the lhs element.
/// Any other box (including an unknown-shape sequence that is neither heap
/// nor pointer) hits a TODO. In both handled cases the original, unloaded
/// \p lhs is yielded with omp.yield.
static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
                           ReductionProcessor::ReductionIdentifier redId,
                           fir::BaseBoxType boxTy, mlir::Value lhs,
                           mlir::Value rhs) {
  // seqTy looks through reference-like wrappers; heapTy/ptrTy only match when
  // the box element type itself is a heap/pointer type.
  fir::SequenceType seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(
      fir::unwrapRefType(boxTy.getEleTy()));
  fir::HeapType heapTy =
      mlir::dyn_cast_or_null<fir::HeapType>(boxTy.getEleTy());
  fir::PointerType ptrTy =
      mlir::dyn_cast_or_null<fir::PointerType>(boxTy.getEleTy());
  if ((!seqTy || seqTy.hasUnknownShape()) && !heapTy && !ptrTy)
    TODO(loc, "Unsupported boxed type in OpenMP reduction" );

  // load fir.ref<fir.box<...>>
  // Keep the incoming lhs: it is the value yielded at the end.
  mlir::Value lhsAddr = lhs;
  lhs = builder.create<fir::LoadOp>(loc, lhs);
  rhs = builder.create<fir::LoadOp>(loc, rhs);

  // Heap/pointer box without a sequence type inside: scalar case.
  if ((heapTy || ptrTy) && !seqTy) {
    // get box contents (heap pointers)
    lhs = builder.create<fir::BoxAddrOp>(loc, lhs);
    rhs = builder.create<fir::BoxAddrOp>(loc, rhs);
    // Address the combined value is stored back to.
    mlir::Value lhsValAddr = lhs;

    // load heap pointers
    lhs = builder.create<fir::LoadOp>(loc, lhs);
    rhs = builder.create<fir::LoadOp>(loc, rhs);

    mlir::Type eleTy = heapTy ? heapTy.getEleTy() : ptrTy.getEleTy();

    mlir::Value result = ReductionProcessor::createScalarCombiner(
        builder, loc, redId, eleTy, lhs, rhs);
    builder.create<fir::StoreOp>(loc, result, lhsValAddr);
    builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
    return;
  }

  // Get ShapeShift with default lower bounds. This makes it possible to use
  // unmodified LoopNest's indices with ArrayCoorOp.
  fir::ShapeShiftOp shapeShift =
      getShapeShift(builder, loc, lhs,
                    /*cannotHaveNonDefaultLowerBounds=*/false,
                    /*useDefaultLowerBounds=*/true);

  // Iterate over array elements, applying the equivalent scalar reduction:

  // F2018 5.4.10.2: Unallocated allocatable variables may not be referenced
  // and so no null check is needed here before indexing into the (possibly
  // allocatable) arrays.

  // A hlfir::elemental here gets inlined with a temporary so create the
  // loop nest directly.
  // This function already controls all of the code in this region so we
  // know this won't miss any opportunities for clever elemental inlining
  hlfir::LoopNest nest = hlfir::genLoopNest(
      loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
  builder.setInsertionPointToStart(nest.body);
  // The element reference type carries the volatility of the element type.
  const bool seqIsVolatile = fir::isa_volatile_type(seqTy.getEleTy());
  mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy(), seqIsVolatile);
  // The same shapeShift (taken from lhs) is used to address both arrays.
  auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto rhsEleAddr = builder.create<fir::ArrayCoorOp>(
      loc, refTy, rhs, shapeShift, /*slice=*/mlir::Value{},
      nest.oneBasedIndices, /*typeparms=*/mlir::ValueRange{});
  auto lhsEle = builder.create<fir::LoadOp>(loc, lhsEleAddr);
  auto rhsEle = builder.create<fir::LoadOp>(loc, rhsEleAddr);
  // createScalarCombiner strips the reference wrapper from refTy itself.
  mlir::Value scalarReduction = ReductionProcessor::createScalarCombiner(
      builder, loc, redId, refTy, lhsEle, rhsEle);
  builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

  // Leave the loop nest before emitting the region terminator.
  builder.setInsertionPointAfter(nest.outerOp);
  builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
409 | |
410 | // generate combiner region for reduction operations |
411 | static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc, |
412 | ReductionProcessor::ReductionIdentifier redId, |
413 | mlir::Type ty, mlir::Value lhs, mlir::Value rhs, |
414 | bool isByRef) { |
415 | ty = fir::unwrapRefType(ty); |
416 | |
417 | if (fir::isa_trivial(ty)) { |
418 | mlir::Value lhsLoaded = builder.loadIfRef(loc, lhs); |
419 | mlir::Value rhsLoaded = builder.loadIfRef(loc, rhs); |
420 | |
421 | mlir::Value result = ReductionProcessor::createScalarCombiner( |
422 | builder, loc, redId, ty, lhsLoaded, rhsLoaded); |
423 | if (isByRef) { |
424 | builder.create<fir::StoreOp>(loc, result, lhs); |
425 | builder.create<mlir::omp::YieldOp>(loc, lhs); |
426 | } else { |
427 | builder.create<mlir::omp::YieldOp>(loc, result); |
428 | } |
429 | return; |
430 | } |
431 | // all arrays should have been boxed |
432 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
433 | genBoxCombiner(builder, loc, redId, boxTy, lhs, rhs); |
434 | return; |
435 | } |
436 | |
437 | TODO(loc, "OpenMP genCombiner for unsupported reduction variable type" ); |
438 | } |
439 | |
440 | // like fir::unwrapSeqOrBoxedSeqType except it also works for non-sequence boxes |
441 | static mlir::Type unwrapSeqOrBoxedType(mlir::Type ty) { |
442 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(ty)) |
443 | return seqTy.getEleTy(); |
444 | if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(ty)) { |
445 | auto eleTy = fir::unwrapRefType(boxTy.getEleTy()); |
446 | if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(eleTy)) |
447 | return seqTy.getEleTy(); |
448 | return eleTy; |
449 | } |
450 | return ty; |
451 | } |
452 | |
/// Populate the alloc and initializer regions of \p reductionDecl (and, for
/// by-reference reductions, hand its cleanup region to
/// populateByRefInitAndCleanupRegions).
///
/// \param converter     Supplies the builder used for all created operations.
/// \param loc           Location attached to blocks and operations.
/// \param reductionDecl Declaration whose regions are filled in.
/// \param redId         Reduction kind; selects the initial value.
/// \param type          Type used for the block arguments of the regions.
/// \param isByRef       When true, an alloc block (no arguments) and a
///                      two-argument initializer block are created; when
///                      false only a one-argument initializer block is
///                      created and the initial value is yielded from it.
static void createReductionAllocAndInitRegions(
    AbstractConverter &converter, mlir::Location loc,
    mlir::omp::DeclareReductionOp &reductionDecl,
    const ReductionProcessor::ReductionIdentifier redId, mlir::Type type,
    bool isByRef) {
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  // Emits the region terminator at the builder's current insertion point.
  auto yield = [&](mlir::Value ret) {
    builder.create<mlir::omp::YieldOp>(loc, ret);
  };

  mlir::Block *allocBlock = nullptr;
  mlir::Block *initBlock = nullptr;
  if (isByRef) {
    allocBlock =
        builder.createBlock(&reductionDecl.getAllocRegion(),
                            reductionDecl.getAllocRegion().end(), {}, {});
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type, type}, {loc, loc});
  } else {
    initBlock = builder.createBlock(&reductionDecl.getInitializerRegion(),
                                    reductionDecl.getInitializerRegion().end(),
                                    {type}, {loc});
  }

  // The initial value is a scalar of the element type, created inside the
  // initializer block.
  mlir::Type ty = fir::unwrapRefType(type);
  builder.setInsertionPointToEnd(initBlock);
  mlir::Value initValue = ReductionProcessor::getReductionInitValue(
      loc, unwrapSeqOrBoxedType(ty), redId, builder);

  if (isByRef) {
    populateByRefInitAndCleanupRegions(
        converter, loc, type, initValue, initBlock,
        reductionDecl.getInitializerAllocArg(),
        reductionDecl.getInitializerMoldArg(), reductionDecl.getCleanupRegion(),
        DeclOperationKind::Reduction);
  }

  if (fir::isa_trivial(ty)) {
    if (isByRef) {
      // alloc region
      builder.setInsertionPointToEnd(allocBlock);
      mlir::Value alloca = builder.create<fir::AllocaOp>(loc, ty);
      yield(alloca);
      return;
    }
    // by val
    yield(initValue);
    return;
  }
  // allocBlock is only created for by-ref reductions, so it must exist here.
  assert(isByRef && "passing non-trivial types by val is unsupported" );

  // alloc region
  builder.setInsertionPointToEnd(allocBlock);
  mlir::Value boxAlloca = builder.create<fir::AllocaOp>(loc, ty);
  yield(boxAlloca);
}
510 | |
511 | mlir::omp::DeclareReductionOp ReductionProcessor::createDeclareReduction( |
512 | AbstractConverter &converter, llvm::StringRef reductionOpName, |
513 | const ReductionIdentifier redId, mlir::Type type, mlir::Location loc, |
514 | bool isByRef) { |
515 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
516 | mlir::OpBuilder::InsertionGuard guard(builder); |
517 | mlir::ModuleOp module = builder.getModule(); |
518 | |
519 | assert(!reductionOpName.empty()); |
520 | |
521 | auto decl = |
522 | module.lookupSymbol<mlir::omp::DeclareReductionOp>(reductionOpName); |
523 | if (decl) |
524 | return decl; |
525 | |
526 | mlir::OpBuilder modBuilder(module.getBodyRegion()); |
527 | mlir::Type valTy = fir::unwrapRefType(type); |
528 | if (!isByRef) |
529 | type = valTy; |
530 | |
531 | decl = modBuilder.create<mlir::omp::DeclareReductionOp>(loc, reductionOpName, |
532 | type); |
533 | createReductionAllocAndInitRegions(converter, loc, decl, redId, type, |
534 | isByRef); |
535 | |
536 | builder.createBlock(&decl.getReductionRegion(), |
537 | decl.getReductionRegion().end(), {type, type}, |
538 | {loc, loc}); |
539 | |
540 | builder.setInsertionPointToEnd(&decl.getReductionRegion().back()); |
541 | mlir::Value op1 = decl.getReductionRegion().front().getArgument(0); |
542 | mlir::Value op2 = decl.getReductionRegion().front().getArgument(1); |
543 | genCombiner(builder, loc, redId, type, op1, op2, isByRef); |
544 | |
545 | return decl; |
546 | } |
547 | |
548 | static bool doReductionByRef(mlir::Value reductionVar) { |
549 | if (forceByrefReduction) |
550 | return true; |
551 | |
552 | if (auto declare = |
553 | mlir::dyn_cast<hlfir::DeclareOp>(reductionVar.getDefiningOp())) |
554 | reductionVar = declare.getMemref(); |
555 | |
556 | if (!fir::isa_trivial(fir::unwrapRefType(reductionVar.getType()))) |
557 | return true; |
558 | |
559 | return false; |
560 | } |
561 | |
562 | mlir::omp::ReductionModifier translateReductionModifier(ReductionModifier mod) { |
563 | switch (mod) { |
564 | case ReductionModifier::Default: |
565 | return mlir::omp::ReductionModifier::defaultmod; |
566 | case ReductionModifier::Inscan: |
567 | return mlir::omp::ReductionModifier::inscan; |
568 | case ReductionModifier::Task: |
569 | return mlir::omp::ReductionModifier::task; |
570 | } |
571 | return mlir::omp::ReductionModifier::defaultmod; |
572 | } |
573 | |
/// Lower the objects of a reduction-like clause \p reduction and append the
/// results to the output vectors.
///
/// \tparam T Clause type; the modifier handling below is only compiled for
///           omp::clause::Reduction.
/// \param currentLocation      Location attached to created operations.
/// \param converter            Provides the builder and symbol lookups.
/// \param reduction            Clause being processed; must carry exactly one
///                             reduction identifier.
/// \param reductionVars        Receives one reference value per object.
/// \param reduceVarByRef       Receives the doReductionByRef() decision for
///                             each object.
/// \param reductionDeclSymbols Receives a symbol reference to the reduction
///                             declaration used for each reduction variable.
/// \param reductionSymbols     Receives the semantic symbol of each object.
/// \param reductionMod         Written through (without a null check) when a
///                             non-`task` modifier is present on a Reduction
///                             clause.
///
/// The function returns early, appending nothing, when the reduction
/// identifier is neither a defined operator nor a procedure designator
/// accepted by supportedIntrinsicProcReduction.
template <class T>
void ReductionProcessor::processReductionArguments(
    mlir::Location currentLocation, lower::AbstractConverter &converter,
    const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
    llvm::SmallVectorImpl<bool> &reduceVarByRef,
    llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
    llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols,
    mlir::omp::ReductionModifierAttr *reductionMod) {
  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

  // Only the Reduction clause is queried for a modifier here.
  if constexpr (std::is_same_v<T, omp::clause::Reduction>) {
    auto mod = std::get<std::optional<ReductionModifier>>(reduction.t);
    if (mod.has_value()) {
      if (mod.value() == ReductionModifier::Task)
        TODO(currentLocation, "Reduction modifier `task` is not supported" );
      else
        *reductionMod = mlir::omp::ReductionModifierAttr::get(
            firOpBuilder.getContext(), translateReductionModifier(mod.value()));
    }
  }

  mlir::omp::DeclareReductionOp decl;
  const auto &redOperatorList{
      std::get<typename T::ReductionIdentifiers>(reduction.t)};
  assert(redOperatorList.size() == 1 && "Expecting single operator" );
  const auto &redOperator = redOperatorList.front();
  const auto &objectList{std::get<omp::ObjectList>(reduction.t)};

  // Bail out silently unless the identifier is a defined operator or a
  // supported intrinsic procedure.
  if (!std::holds_alternative<omp::clause::DefinedOperator>(redOperator.u)) {
    if (const auto *reductionIntrinsic =
            std::get_if<omp::clause::ProcedureDesignator>(&redOperator.u)) {
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        return;
      }
    } else {
      return;
    }
  }

  // Reduction variable processing common to both intrinsic operators and
  // procedure designators
  // (`builder` and `firOpBuilder` refer to the same converter builder.)
  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  for (const Object &object : objectList) {
    const semantics::Symbol *symbol = object.sym();
    reductionSymbols.push_back(symbol);
    mlir::Value symVal = converter.getSymbolAddress(*symbol);
    mlir::Type eleType;
    auto refType = mlir::dyn_cast_or_null<fir::ReferenceType>(symVal.getType());
    if (refType)
      eleType = refType.getEleTy();
    else
      eleType = symVal.getType();

    // all arrays must be boxed so that we have convenient access to all the
    // information needed to iterate over the array
    if (mlir::isa<fir::SequenceType>(eleType)) {
      // For Host associated symbols, use `SymbolBox` instead
      lower::SymbolBox symBox = converter.lookupOneLevelUpSymbol(*symbol);
      hlfir::Entity entity{symBox.getAddr()};
      entity = genVariableBox(currentLocation, builder, entity);
      mlir::Value box = entity.getBase();

      // Always pass the box by reference so that the OpenMP dialect
      // verifiers don't need to know anything about fir.box
      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, box.getType());
      builder.create<fir::StoreOp>(currentLocation, box, alloca);

      symVal = alloca;
    } else if (mlir::isa<fir::BaseBoxType>(symVal.getType())) {
      // boxed arrays are passed as values not by reference. Unfortunately,
      // we can't pass a box by value to omp.reduction_declare, so turn it
      // into a reference

      auto alloca =
          builder.create<fir::AllocaOp>(currentLocation, symVal.getType());
      builder.create<fir::StoreOp>(currentLocation, symVal, alloca);
      symVal = alloca;
    } else if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>()) {
      symVal = declOp.getBase();
    }

    // this isn't the same as the by-val and by-ref passing later in the
    // pipeline. Both styles assume that the variable is a reference at
    // this point
    assert(fir::isa_ref_type(symVal.getType()) &&
           "reduction input var is passed by reference" );
    // Normalise to a fir.ref of the pointee type, keeping volatility.
    mlir::Type elementType = fir::dyn_cast_ptrEleTy(symVal.getType());
    const bool symIsVolatile = fir::isa_volatile_type(symVal.getType());
    mlir::Type refTy = fir::ReferenceType::get(elementType, symIsVolatile);

    reductionVars.push_back(
        builder.createConvert(currentLocation, refTy, symVal));
    reduceVarByRef.push_back(doReductionByRef(symVal));
  }

  // Second pass: find or create the reduction declaration for every entry of
  // reductionVars / reduceVarByRef.
  for (auto [symVal, isByRef] : llvm::zip(reductionVars, reduceVarByRef)) {
    auto redType = mlir::cast<fir::ReferenceType>(symVal.getType());
    const auto &kindMap = firOpBuilder.getKindMap();
    std::string reductionName;
    ReductionIdentifier redId;
    // Logical variables use i1 as the type passed to getReductionName.
    mlir::Type redNameTy = redType;
    if (mlir::isa<fir::LogicalType>(redType.getEleTy()))
      redNameTy = builder.getI1Type();

    if (const auto &redDefinedOp =
            std::get_if<omp::clause::DefinedOperator>(&redOperator.u)) {
      const auto &intrinsicOp{
          std::get<omp::clause::DefinedOperator::IntrinsicOperator>(
              redDefinedOp->u)};
      redId = getReductionType(intrinsicOp);
      // Only these operator reductions are lowered; the rest hit the TODO.
      switch (redId) {
      case ReductionIdentifier::ADD:
      case ReductionIdentifier::MULTIPLY:
      case ReductionIdentifier::AND:
      case ReductionIdentifier::EQV:
      case ReductionIdentifier::OR:
      case ReductionIdentifier::NEQV:
        break;
      default:
        TODO(currentLocation,
             "Reduction of some intrinsic operators is not supported" );
        break;
      }

      reductionName =
          getReductionName(intrinsicOp, kindMap, redNameTy, isByRef);
    } else if (const auto *reductionIntrinsic =
                   std::get_if<omp::clause::ProcedureDesignator>(
                       &redOperator.u)) {
      if (!ReductionProcessor::supportedIntrinsicProcReduction(
              *reductionIntrinsic)) {
        TODO(currentLocation, "Unsupported intrinsic proc reduction" );
      }
      redId = getReductionType(*reductionIntrinsic);
      reductionName =
          getReductionName(getRealName(*reductionIntrinsic).ToString(), kindMap,
                           redNameTy, isByRef);
    } else {
      TODO(currentLocation, "Unexpected reduction type" );
    }

    decl = createDeclareReduction(converter, reductionName, redId, redType,
                                  currentLocation, isByRef);
    reductionDeclSymbols.push_back(
        mlir::SymbolRefAttr::get(firOpBuilder.getContext(), decl.getSymName()));
  }
}
723 | |
724 | const semantics::SourceName |
725 | ReductionProcessor::getRealName(const semantics::Symbol *symbol) { |
726 | return symbol->GetUltimate().name(); |
727 | } |
728 | |
729 | const semantics::SourceName |
730 | ReductionProcessor::getRealName(const omp::clause::ProcedureDesignator &pd) { |
731 | return getRealName(pd.v.sym()); |
732 | } |
733 | |
734 | int ReductionProcessor::getOperationIdentity(ReductionIdentifier redId, |
735 | mlir::Location loc) { |
736 | switch (redId) { |
737 | case ReductionIdentifier::ADD: |
738 | case ReductionIdentifier::OR: |
739 | case ReductionIdentifier::NEQV: |
740 | return 0; |
741 | case ReductionIdentifier::MULTIPLY: |
742 | case ReductionIdentifier::AND: |
743 | case ReductionIdentifier::EQV: |
744 | return 1; |
745 | default: |
746 | TODO(loc, "Reduction of some intrinsic operators is not supported" ); |
747 | } |
748 | } |
749 | |
750 | } // namespace omp |
751 | } // namespace lower |
752 | } // namespace Fortran |
753 | |