1//===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to convert scf.parallel operations into OpenMP
10// parallel loops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h"
15
16#include "mlir/Analysis/SliceAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20#include "mlir/Dialect/MemRef/IR/MemRef.h"
21#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
22#include "mlir/Dialect/SCF/IR/SCF.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/DialectConversion.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
29#include "mlir/Conversion/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33
34/// Matches a block containing a "simple" reduction. The expected shape of the
35/// block is as follows.
36///
37/// ^bb(%arg0, %arg1):
38/// %0 = OpTy(%arg0, %arg1)
39/// scf.reduce.return %0
40template <typename... OpTy>
41static bool matchSimpleReduction(Block &block) {
42 if (block.empty() || llvm::hasSingleElement(C&: block) ||
43 std::next(x: block.begin(), n: 2) != block.end())
44 return false;
45
46 if (block.getNumArguments() != 2)
47 return false;
48
49 SmallVector<Operation *, 4> combinerOps;
50 Value reducedVal = matchReduction(iterCarriedArgs: {block.getArguments()[1]},
51 /*redPos=*/0, combinerOps);
52
53 if (!reducedVal || !isa<BlockArgument>(Val: reducedVal) || combinerOps.size() != 1)
54 return false;
55
56 return isa<OpTy...>(combinerOps[0]) &&
57 isa<scf::ReduceReturnOp>(Val: block.back()) &&
58 block.front().getOperands() == block.getArguments();
59}
60
61/// Matches a block containing a select-based min/max reduction. The types of
62/// select and compare operations are provided as template arguments. The
63/// comparison predicates suitable for min and max are provided as function
64/// arguments. If a reduction is matched, `ifMin` will be set if the reduction
65/// compute the minimum and unset if it computes the maximum, otherwise it
66/// remains unmodified. The expected shape of the block is as follows.
67///
68/// ^bb(%arg0, %arg1):
69/// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1)
70/// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here.
71/// scf.reduce.return %1
72template <
73 typename CompareOpTy, typename SelectOpTy,
74 typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())>
75static bool
76matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates,
77 ArrayRef<Predicate> greaterThanPredicates, bool &isMin) {
78 static_assert(
79 llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value,
80 "only arithmetic and llvm select ops are supported");
81
82 // Expect exactly three operations in the block.
83 if (block.empty() || llvm::hasSingleElement(C&: block) ||
84 std::next(x: block.begin(), n: 2) == block.end() ||
85 std::next(x: block.begin(), n: 3) != block.end())
86 return false;
87
88 // Check op kinds.
89 auto compare = dyn_cast<CompareOpTy>(block.front());
90 auto select = dyn_cast<SelectOpTy>(block.front().getNextNode());
91 auto terminator = dyn_cast<scf::ReduceReturnOp>(Val&: block.back());
92 if (!compare || !select || !terminator)
93 return false;
94
95 // Block arguments must be compared.
96 if (compare->getOperands() != block.getArguments())
97 return false;
98
99 // Detect whether the comparison is less-than or greater-than, otherwise bail.
100 bool isLess;
101 if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) {
102 isLess = true;
103 } else if (llvm::is_contained(greaterThanPredicates,
104 compare.getPredicate())) {
105 isLess = false;
106 } else {
107 return false;
108 }
109
110 if (select.getCondition() != compare.getResult())
111 return false;
112
113 // Detect if the operands are swapped between cmpf and select. Match the
114 // comparison type with the requested type or with the opposite of the
115 // requested type if the operands are swapped. Use generic accessors because
116 // std and LLVM versions of select have different operand names but identical
117 // positions.
118 constexpr unsigned kTrueValue = 1;
119 constexpr unsigned kFalseValue = 2;
120 bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() &&
121 select.getOperand(kFalseValue) == compare.getRhs();
122 bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() &&
123 select.getOperand(kFalseValue) == compare.getLhs();
124 if (!sameOperands && !swappedOperands)
125 return false;
126
127 if (select.getResult() != terminator.getResult())
128 return false;
129
130 // The reduction is a min if it uses less-than predicates with same operands
131 // or greather-than predicates with swapped operands. Similarly for max.
132 isMin = (isLess && sameOperands) || (!isLess && swappedOperands);
133 return isMin || (isLess & swappedOperands) || (!isLess && sameOperands);
134}
135
136/// Returns the float semantics for the given float type.
137static const llvm::fltSemantics &fltSemanticsForType(FloatType type) {
138 if (type.isF16())
139 return llvm::APFloat::IEEEhalf();
140 if (type.isF32())
141 return llvm::APFloat::IEEEsingle();
142 if (type.isF64())
143 return llvm::APFloat::IEEEdouble();
144 if (type.isF128())
145 return llvm::APFloat::IEEEquad();
146 if (type.isBF16())
147 return llvm::APFloat::BFloat();
148 if (type.isF80())
149 return llvm::APFloat::x87DoubleExtended();
150 llvm_unreachable("unknown float type");
151}
152
153/// Returns an attribute with the minimum (if `min` is set) or the maximum value
154/// (otherwise) for the given float type.
155static Attribute minMaxValueForFloat(Type type, bool min) {
156 auto fltType = cast<FloatType>(Val&: type);
157 return FloatAttr::get(
158 type, value: llvm::APFloat::getLargest(Sem: fltSemanticsForType(type: fltType), Negative: min));
159}
160
161/// Returns an attribute with the signed integer minimum (if `min` is set) or
162/// the maximum value (otherwise) for the given integer type, regardless of its
163/// signedness semantics (only the width is considered).
164static Attribute minMaxValueForSignedInt(Type type, bool min) {
165 auto intType = cast<IntegerType>(Val&: type);
166 unsigned bitwidth = intType.getWidth();
167 return IntegerAttr::get(type, value: min ? llvm::APInt::getSignedMinValue(numBits: bitwidth)
168 : llvm::APInt::getSignedMaxValue(numBits: bitwidth));
169}
170
171/// Returns an attribute with the unsigned integer minimum (if `min` is set) or
172/// the maximum value (otherwise) for the given integer type, regardless of its
173/// signedness semantics (only the width is considered).
174static Attribute minMaxValueForUnsignedInt(Type type, bool min) {
175 auto intType = cast<IntegerType>(Val&: type);
176 unsigned bitwidth = intType.getWidth();
177 return IntegerAttr::get(type, value: min ? llvm::APInt::getZero(numBits: bitwidth)
178 : llvm::APInt::getAllOnes(numBits: bitwidth));
179}
180
181/// Creates an OpenMP reduction declaration and inserts it into the provided
182/// symbol table. The declaration has a constant initializer with the neutral
183/// value `initValue`, and the `reductionIndex`-th reduction combiner carried
184/// over from `reduce`.
185static omp::DeclareReductionOp
186createDecl(PatternRewriter &builder, SymbolTable &symbolTable,
187 scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) {
188 OpBuilder::InsertionGuard guard(builder);
189 Type type = reduce.getOperands()[reductionIndex].getType();
190 auto decl = builder.create<omp::DeclareReductionOp>(location: reduce.getLoc(),
191 args: "__scf_reduction", args&: type);
192 symbolTable.insert(symbol: decl);
193
194 builder.createBlock(parent: &decl.getInitializerRegion(),
195 insertPt: decl.getInitializerRegion().end(), argTypes: {type},
196 locs: {reduce.getOperands()[reductionIndex].getLoc()});
197 builder.setInsertionPointToEnd(&decl.getInitializerRegion().back());
198 Value init =
199 builder.create<LLVM::ConstantOp>(location: reduce.getLoc(), args&: type, args&: initValue);
200 builder.create<omp::YieldOp>(location: reduce.getLoc(), args&: init);
201
202 Operation *terminator =
203 &reduce.getReductions()[reductionIndex].front().back();
204 assert(isa<scf::ReduceReturnOp>(terminator) &&
205 "expected reduce op to be terminated by redure return");
206 builder.setInsertionPoint(terminator);
207 builder.replaceOpWithNewOp<omp::YieldOp>(op: terminator,
208 args: terminator->getOperands());
209 builder.inlineRegionBefore(region&: reduce.getReductions()[reductionIndex],
210 parent&: decl.getReductionRegion(),
211 before: decl.getReductionRegion().end());
212 return decl;
213}
214
215/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
216/// using llvm.atomicrmw of the given kind.
217static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder,
218 LLVM::AtomicBinOp atomicKind,
219 omp::DeclareReductionOp decl,
220 scf::ReduceOp reduce,
221 int64_t reductionIndex) {
222 OpBuilder::InsertionGuard guard(builder);
223 auto ptrType = LLVM::LLVMPointerType::get(context: builder.getContext());
224 Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc();
225 builder.createBlock(parent: &decl.getAtomicReductionRegion(),
226 insertPt: decl.getAtomicReductionRegion().end(), argTypes: {ptrType, ptrType},
227 locs: {reduceOperandLoc, reduceOperandLoc});
228 Block *atomicBlock = &decl.getAtomicReductionRegion().back();
229 builder.setInsertionPointToEnd(atomicBlock);
230 Value loaded = builder.create<LLVM::LoadOp>(location: reduce.getLoc(), args: decl.getType(),
231 args: atomicBlock->getArgument(i: 1));
232 builder.create<LLVM::AtomicRMWOp>(location: reduce.getLoc(), args&: atomicKind,
233 args: atomicBlock->getArgument(i: 0), args&: loaded,
234 args: LLVM::AtomicOrdering::monotonic);
235 builder.create<omp::YieldOp>(location: reduce.getLoc(), args: ArrayRef<Value>());
236 return decl;
237}
238
239/// Creates an OpenMP reduction declaration that corresponds to the given SCF
240/// reduction and returns it. Recognizes common reductions in order to identify
241/// the neutral value, necessary for the OpenMP declaration. If the reduction
242/// cannot be recognized, returns null.
243static omp::DeclareReductionOp declareReduction(PatternRewriter &builder,
244 scf::ReduceOp reduce,
245 int64_t reductionIndex) {
246 Operation *container = SymbolTable::getNearestSymbolTable(from: reduce);
247 SymbolTable symbolTable(container);
248
249 // Insert reduction declarations in the symbol-table ancestor before the
250 // ancestor of the current insertion point.
251 Operation *insertionPoint = reduce;
252 while (insertionPoint->getParentOp() != container)
253 insertionPoint = insertionPoint->getParentOp();
254 OpBuilder::InsertionGuard guard(builder);
255 builder.setInsertionPoint(insertionPoint);
256
257 assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) &&
258 "expected reduction region to have a single element");
259
260 // Match simple binary reductions that can be expressed with atomicrmw.
261 Type type = reduce.getOperands()[reductionIndex].getType();
262 Block &reduction = reduce.getReductions()[reductionIndex].front();
263 if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(block&: reduction)) {
264 omp::DeclareReductionOp decl =
265 createDecl(builder, symbolTable, reduce, reductionIndex,
266 initValue: builder.getFloatAttr(type, value: 0.0));
267 return addAtomicRMW(builder, atomicKind: LLVM::AtomicBinOp::fadd, decl, reduce,
268 reductionIndex);
269 }
270 if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(block&: reduction)) {
271 omp::DeclareReductionOp decl =
272 createDecl(builder, symbolTable, reduce, reductionIndex,
273 initValue: builder.getIntegerAttr(type, value: 0));
274 return addAtomicRMW(builder, atomicKind: LLVM::AtomicBinOp::add, decl, reduce,
275 reductionIndex);
276 }
277 if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(block&: reduction)) {
278 omp::DeclareReductionOp decl =
279 createDecl(builder, symbolTable, reduce, reductionIndex,
280 initValue: builder.getIntegerAttr(type, value: 0));
281 return addAtomicRMW(builder, atomicKind: LLVM::AtomicBinOp::_or, decl, reduce,
282 reductionIndex);
283 }
284 if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(block&: reduction)) {
285 omp::DeclareReductionOp decl =
286 createDecl(builder, symbolTable, reduce, reductionIndex,
287 initValue: builder.getIntegerAttr(type, value: 0));
288 return addAtomicRMW(builder, atomicKind: LLVM::AtomicBinOp::_xor, decl, reduce,
289 reductionIndex);
290 }
291 if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(block&: reduction)) {
292 omp::DeclareReductionOp decl = createDecl(
293 builder, symbolTable, reduce, reductionIndex,
294 initValue: builder.getIntegerAttr(
295 type, value: llvm::APInt::getAllOnes(numBits: type.getIntOrFloatBitWidth())));
296 return addAtomicRMW(builder, atomicKind: LLVM::AtomicBinOp::_and, decl, reduce,
297 reductionIndex);
298 }
299
300 // Match simple binary reductions that cannot be expressed with atomicrmw.
301 // TODO: add atomic region using cmpxchg (which needs atomic load to be
302 // available as an op).
303 if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(block&: reduction)) {
304 return createDecl(builder, symbolTable, reduce, reductionIndex,
305 initValue: builder.getFloatAttr(type, value: 1.0));
306 }
307 if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(block&: reduction)) {
308 return createDecl(builder, symbolTable, reduce, reductionIndex,
309 initValue: builder.getIntegerAttr(type, value: 1));
310 }
311
312 // Match select-based min/max reductions.
313 bool isMin;
314 if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>(
315 block&: reduction, lessThanPredicates: {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE},
316 greaterThanPredicates: {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) ||
317 matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>(
318 block&: reduction, lessThanPredicates: {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole},
319 greaterThanPredicates: {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) {
320 return createDecl(builder, symbolTable, reduce, reductionIndex,
321 initValue: minMaxValueForFloat(type, min: !isMin));
322 }
323 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
324 block&: reduction, lessThanPredicates: {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle},
325 greaterThanPredicates: {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) ||
326 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
327 block&: reduction, lessThanPredicates: {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle},
328 greaterThanPredicates: {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) {
329 omp::DeclareReductionOp decl =
330 createDecl(builder, symbolTable, reduce, reductionIndex,
331 initValue: minMaxValueForSignedInt(type, min: !isMin));
332 return addAtomicRMW(builder,
333 atomicKind: isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
334 decl, reduce, reductionIndex);
335 }
336 if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
337 block&: reduction, lessThanPredicates: {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
338 greaterThanPredicates: {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) ||
339 matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>(
340 block&: reduction, lessThanPredicates: {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule},
341 greaterThanPredicates: {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) {
342 omp::DeclareReductionOp decl =
343 createDecl(builder, symbolTable, reduce, reductionIndex,
344 initValue: minMaxValueForUnsignedInt(type, min: !isMin));
345 return addAtomicRMW(
346 builder, atomicKind: isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
347 decl, reduce, reductionIndex);
348 }
349
350 return nullptr;
351}
352
353namespace {
354
/// Lowers scf.parallel (with optional reductions) into an omp.parallel region
/// containing an omp.wsloop-wrapped omp.loop_nest, with reduction values
/// carried through stack-allocated LLVM pointers.
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
  // Sentinel meaning "let the OpenMP runtime choose the number of threads".
  static constexpr unsigned kUseOpenMPDefaultNumThreads = 0;
  // Requested thread count for the generated omp.parallel; 0 defers to the
  // runtime default.
  unsigned numThreads;

  ParallelOpLowering(MLIRContext *context,
                     unsigned numThreads = kUseOpenMPDefaultNumThreads)
      : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {}

  LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                PatternRewriter &rewriter) const override {
    // Declare reductions.
    // TODO: consider checking if there is already a compatible reduction
    // declaration and use it instead of redeclaring.
    SmallVector<Attribute> reductionSyms;
    SmallVector<omp::DeclareReductionOp> ompReductionDecls;
    auto reduce = cast<scf::ReduceOp>(Val: parallelOp.getBody()->getTerminator());
    for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) {
      omp::DeclareReductionOp decl = declareReduction(builder&: rewriter, reduce, reductionIndex: i);
      ompReductionDecls.push_back(Elt: decl);
      // declareReduction returns null for unrecognized reductions; the
      // pattern cannot apply in that case.
      if (!decl)
        return failure();
      reductionSyms.push_back(
          Elt: SymbolRefAttr::get(ctx: rewriter.getContext(), value: decl.getSymName()));
    }

    // Allocate reduction variables. Make sure we don't overflow the stack
    // with local `alloca`s by saving and restoring the stack pointer.
    Location loc = parallelOp.getLoc();
    Value one = rewriter.create<LLVM::ConstantOp>(
        location: loc, args: rewriter.getIntegerType(width: 64), args: rewriter.getI64IntegerAttr(value: 1));
    SmallVector<Value> reductionVariables;
    reductionVariables.reserve(N: parallelOp.getNumReductions());
    auto ptrType = LLVM::LLVMPointerType::get(context: parallelOp.getContext());
    for (Value init : parallelOp.getInitVals()) {
      assert((LLVM::isCompatibleType(init.getType()) ||
              isa<LLVM::PointerElementTypeInterface>(init.getType())) &&
             "cannot create a reduction variable if the type is not an LLVM "
             "pointer element");
      // One stack slot per reduction, seeded with the SCF init value.
      Value storage =
          rewriter.create<LLVM::AllocaOp>(location: loc, args&: ptrType, args: init.getType(), args&: one, args: 0);
      rewriter.create<LLVM::StoreOp>(location: loc, args&: init, args&: storage);
      reductionVariables.push_back(Elt: storage);
    }

    // Replace the reduction operations contained in this loop. Must be done
    // here rather than in a separate pattern to have access to the list of
    // reduction variables.
    for (auto [x, y, rD] : llvm::zip_equal(
             t&: reductionVariables, u: reduce.getOperands(), args&: ompReductionDecls)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPoint(reduce);
      Region &redRegion = rD.getReductionRegion();
      // The SCF dialect by definition contains only structured operations
      // and hence the SCF reduction region will contain a single block.
      // The ompReductionDecls region is a copy of the SCF reduction region
      // and hence has the same property.
      assert(redRegion.hasOneBlock() &&
             "expect reduction region to have one block");
      // Thread the per-thread (private) reduction variable through the loop
      // region as a new block argument.
      Value pvtRedVar = parallelOp.getRegion().addArgument(type: x.getType(), loc);
      Value pvtRedVal = rewriter.create<LLVM::LoadOp>(location: reduce.getLoc(),
                                                      args: rD.getType(), args&: pvtRedVar);
      // Make a copy of the reduction combiner region in the body
      mlir::OpBuilder builder(rewriter.getContext());
      builder.setInsertionPoint(reduce);
      mlir::IRMapping mapper;
      assert(redRegion.getNumArguments() == 2 &&
             "expect reduction region to have two arguments");
      // Map combiner arguments: arg0 = current private value, arg1 = the
      // value contributed by this iteration.
      mapper.map(from: redRegion.getArgument(i: 0), to: pvtRedVal);
      mapper.map(from: redRegion.getArgument(i: 1), to: y);
      for (auto &op : redRegion.getOps()) {
        Operation *cloneOp = builder.clone(op, mapper);
        if (auto yieldOp = dyn_cast<omp::YieldOp>(Val&: *cloneOp)) {
          assert(yieldOp && yieldOp.getResults().size() == 1 &&
                 "expect YieldOp in reduction region to return one result");
          // Store the combined value back into the private variable and drop
          // the cloned yield, which has no meaning at this level.
          Value redVal = yieldOp.getResults()[0];
          rewriter.create<LLVM::StoreOp>(location: loc, args&: redVal, args&: pvtRedVar);
          rewriter.eraseOp(op: yieldOp);
          break;
        }
      }
    }
    rewriter.eraseOp(op: reduce);

    // Materialize the thread count only when explicitly requested.
    Value numThreadsVar;
    if (numThreads > 0) {
      numThreadsVar = rewriter.create<LLVM::ConstantOp>(
          location: loc, args: rewriter.getI32IntegerAttr(value: numThreads));
    }
    // Create the parallel wrapper.
    auto ompParallel = rewriter.create<omp::ParallelOp>(
        location: loc,
        /* allocate_vars = */ args: llvm::SmallVector<Value>{},
        /* allocator_vars = */ args: llvm::SmallVector<Value>{},
        /* if_expr = */ args: Value{},
        /* num_threads = */ args&: numThreadsVar,
        /* private_vars = */ args: ValueRange(),
        /* private_syms = */ args: nullptr,
        /* private_needs_barrier = */ args: nullptr,
        /* proc_bind_kind = */ args: omp::ClauseProcBindKindAttr{},
        /* reduction_mod = */ args: nullptr,
        /* reduction_vars = */ args: llvm::SmallVector<Value>{},
        /* reduction_byref = */ args: DenseBoolArrayAttr{},
        /* reduction_syms = */ args: ArrayAttr{});
    {

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.createBlock(parent: &ompParallel.getRegion());

      // Replace the loop.
      {
        OpBuilder::InsertionGuard allocaGuard(rewriter);
        // Create worksharing loop wrapper.
        auto wsloopOp = rewriter.create<omp::WsloopOp>(location: parallelOp.getLoc());
        if (!reductionVariables.empty()) {
          // Attach the reduction declarations and variables to the wsloop.
          wsloopOp.setReductionSymsAttr(
              ArrayAttr::get(context: rewriter.getContext(), value: reductionSyms));
          wsloopOp.getReductionVarsMutable().append(values: reductionVariables);
          llvm::SmallVector<bool> reductionByRef;
          // false because these reductions always reduce scalars and so do
          // not need to pass by reference
          reductionByRef.resize(N: reductionVariables.size(), NV: false);
          wsloopOp.setReductionByref(
              DenseBoolArrayAttr::get(context: rewriter.getContext(), content: reductionByRef));
        }
        rewriter.create<omp::TerminatorOp>(location: loc); // omp.parallel terminator.

        // The wrapper's entry block arguments will define the reduction
        // variables.
        llvm::SmallVector<mlir::Type> reductionTypes;
        reductionTypes.reserve(N: reductionVariables.size());
        llvm::transform(Range&: reductionVariables, d_first: std::back_inserter(x&: reductionTypes),
                        F: [](mlir::Value v) { return v.getType(); });
        rewriter.createBlock(
            parent: &wsloopOp.getRegion(), insertPt: {}, argTypes: reductionTypes,
            locs: llvm::SmallVector<mlir::Location>(reductionVariables.size(),
                                              parallelOp.getLoc()));

        // Create loop nest and populate region with contents of scf.parallel.
        auto loopOp = rewriter.create<omp::LoopNestOp>(
            location: parallelOp.getLoc(), args: parallelOp.getLowerBound(),
            args: parallelOp.getUpperBound(), args: parallelOp.getStep());

        rewriter.inlineRegionBefore(region&: parallelOp.getRegion(), parent&: loopOp.getRegion(),
                                    before: loopOp.getRegion().begin());

        // Remove reduction-related block arguments from omp.loop_nest and
        // redirect uses to the corresponding omp.wsloop block argument.
        mlir::Block &loopOpEntryBlock = loopOp.getRegion().front();
        unsigned numLoops = parallelOp.getNumLoops();
        rewriter.replaceAllUsesWith(
            from: loopOpEntryBlock.getArguments().drop_front(N: numLoops),
            to: wsloopOp.getRegion().getArguments());
        loopOpEntryBlock.eraseArguments(
            start: numLoops, num: loopOpEntryBlock.getNumArguments() - numLoops);

        // Split the loop body off so it can be wrapped in an alloca scope,
        // which saves/restores the stack pointer around the body.
        Block *ops =
            rewriter.splitBlock(block: &loopOpEntryBlock, before: loopOpEntryBlock.begin());
        rewriter.setInsertionPointToStart(&loopOpEntryBlock);

        auto scope = rewriter.create<memref::AllocaScopeOp>(location: parallelOp.getLoc(),
                                                            args: TypeRange());
        rewriter.create<omp::YieldOp>(location: loc, args: ValueRange());
        Block *scopeBlock = rewriter.createBlock(parent: &scope.getBodyRegion());
        rewriter.mergeBlocks(source: ops, dest: scopeBlock);
        rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin());
        rewriter.create<memref::AllocaScopeReturnOp>(location: loc, args: ValueRange());
      }
    }

    // Load loop results.
    SmallVector<Value> results;
    results.reserve(N: reductionVariables.size());
    for (auto [variable, type] :
         llvm::zip(t&: reductionVariables, u: parallelOp.getResultTypes())) {
      Value res = rewriter.create<LLVM::LoadOp>(location: loc, args&: type, args&: variable);
      results.push_back(Elt: res);
    }
    rewriter.replaceOp(op: parallelOp, newValues: results);

    return success();
  }
};
537
538/// Applies the conversion patterns in the given function.
539static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
540 ConversionTarget target(*module.getContext());
541 target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
542 target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
543 memref::MemRefDialect>();
544
545 RewritePatternSet patterns(module.getContext());
546 patterns.add<ParallelOpLowering>(arg: module.getContext(), args&: numThreads);
547 FrozenRewritePatternSet frozen(std::move(patterns));
548 return applyPartialConversion(op: module, target, patterns: frozen);
549}
550
551/// A pass converting SCF operations to OpenMP operations.
552struct SCFToOpenMPPass
553 : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
554
555 using Base::Base;
556
557 /// Pass entry point.
558 void runOnOperation() override {
559 if (failed(Result: applyPatterns(module: getOperation(), numThreads)))
560 signalPassFailure();
561 }
562};
563
564} // namespace
565

source code of mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp