1 | //===- SCFToOpenMP.cpp - Structured Control Flow to OpenMP conversion -----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to convert scf.parallel operations into OpenMP |
10 | // parallel loops. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h" |
15 | |
16 | #include "mlir/Analysis/SliceAnalysis.h" |
17 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" |
18 | #include "mlir/Dialect/Arith/IR/Arith.h" |
19 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
21 | #include "mlir/Dialect/OpenMP/OpenMPDialect.h" |
22 | #include "mlir/Dialect/SCF/IR/SCF.h" |
23 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
24 | #include "mlir/IR/SymbolTable.h" |
25 | #include "mlir/Pass/Pass.h" |
26 | #include "mlir/Transforms/DialectConversion.h" |
27 | |
28 | namespace mlir { |
29 | #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS |
30 | #include "mlir/Conversion/Passes.h.inc" |
31 | } // namespace mlir |
32 | |
33 | using namespace mlir; |
34 | |
35 | /// Matches a block containing a "simple" reduction. The expected shape of the |
36 | /// block is as follows. |
37 | /// |
38 | /// ^bb(%arg0, %arg1): |
39 | /// %0 = OpTy(%arg0, %arg1) |
40 | /// scf.reduce.return %0 |
41 | template <typename... OpTy> |
42 | static bool matchSimpleReduction(Block &block) { |
43 | if (block.empty() || llvm::hasSingleElement(block) || |
44 | std::next(block.begin(), 2) != block.end()) |
45 | return false; |
46 | |
47 | if (block.getNumArguments() != 2) |
48 | return false; |
49 | |
50 | SmallVector<Operation *, 4> combinerOps; |
51 | Value reducedVal = matchReduction({block.getArguments()[1]}, |
52 | /*redPos=*/0, combinerOps); |
53 | |
54 | if (!reducedVal || !isa<BlockArgument>(reducedVal) || combinerOps.size() != 1) |
55 | return false; |
56 | |
57 | return isa<OpTy...>(combinerOps[0]) && |
58 | isa<scf::ReduceReturnOp>(block.back()) && |
59 | block.front().getOperands() == block.getArguments(); |
60 | } |
61 | |
62 | /// Matches a block containing a select-based min/max reduction. The types of |
63 | /// select and compare operations are provided as template arguments. The |
64 | /// comparison predicates suitable for min and max are provided as function |
65 | /// arguments. If a reduction is matched, `ifMin` will be set if the reduction |
66 | /// compute the minimum and unset if it computes the maximum, otherwise it |
67 | /// remains unmodified. The expected shape of the block is as follows. |
68 | /// |
69 | /// ^bb(%arg0, %arg1): |
70 | /// %0 = CompareOpTy(<one-of-predicates>, %arg0, %arg1) |
71 | /// %1 = SelectOpTy(%0, %arg0, %arg1) // %arg0, %arg1 may be swapped here. |
72 | /// scf.reduce.return %1 |
73 | template < |
74 | typename CompareOpTy, typename SelectOpTy, |
75 | typename Predicate = decltype(std::declval<CompareOpTy>().getPredicate())> |
76 | static bool |
77 | matchSelectReduction(Block &block, ArrayRef<Predicate> lessThanPredicates, |
78 | ArrayRef<Predicate> greaterThanPredicates, bool &isMin) { |
79 | static_assert( |
80 | llvm::is_one_of<SelectOpTy, arith::SelectOp, LLVM::SelectOp>::value, |
81 | "only arithmetic and llvm select ops are supported" ); |
82 | |
83 | // Expect exactly three operations in the block. |
84 | if (block.empty() || llvm::hasSingleElement(C&: block) || |
85 | std::next(x: block.begin(), n: 2) == block.end() || |
86 | std::next(x: block.begin(), n: 3) != block.end()) |
87 | return false; |
88 | |
89 | // Check op kinds. |
90 | auto compare = dyn_cast<CompareOpTy>(block.front()); |
91 | auto select = dyn_cast<SelectOpTy>(block.front().getNextNode()); |
92 | auto terminator = dyn_cast<scf::ReduceReturnOp>(block.back()); |
93 | if (!compare || !select || !terminator) |
94 | return false; |
95 | |
96 | // Block arguments must be compared. |
97 | if (compare->getOperands() != block.getArguments()) |
98 | return false; |
99 | |
100 | // Detect whether the comparison is less-than or greater-than, otherwise bail. |
101 | bool isLess; |
102 | if (llvm::is_contained(lessThanPredicates, compare.getPredicate())) { |
103 | isLess = true; |
104 | } else if (llvm::is_contained(greaterThanPredicates, |
105 | compare.getPredicate())) { |
106 | isLess = false; |
107 | } else { |
108 | return false; |
109 | } |
110 | |
111 | if (select.getCondition() != compare.getResult()) |
112 | return false; |
113 | |
114 | // Detect if the operands are swapped between cmpf and select. Match the |
115 | // comparison type with the requested type or with the opposite of the |
116 | // requested type if the operands are swapped. Use generic accessors because |
117 | // std and LLVM versions of select have different operand names but identical |
118 | // positions. |
119 | constexpr unsigned kTrueValue = 1; |
120 | constexpr unsigned kFalseValue = 2; |
121 | bool sameOperands = select.getOperand(kTrueValue) == compare.getLhs() && |
122 | select.getOperand(kFalseValue) == compare.getRhs(); |
123 | bool swappedOperands = select.getOperand(kTrueValue) == compare.getRhs() && |
124 | select.getOperand(kFalseValue) == compare.getLhs(); |
125 | if (!sameOperands && !swappedOperands) |
126 | return false; |
127 | |
128 | if (select.getResult() != terminator.getResult()) |
129 | return false; |
130 | |
131 | // The reduction is a min if it uses less-than predicates with same operands |
132 | // or greather-than predicates with swapped operands. Similarly for max. |
133 | isMin = (isLess && sameOperands) || (!isLess && swappedOperands); |
134 | return isMin || (isLess & swappedOperands) || (!isLess && sameOperands); |
135 | } |
136 | |
137 | /// Returns the float semantics for the given float type. |
138 | static const llvm::fltSemantics &fltSemanticsForType(FloatType type) { |
139 | if (type.isF16()) |
140 | return llvm::APFloat::IEEEhalf(); |
141 | if (type.isF32()) |
142 | return llvm::APFloat::IEEEsingle(); |
143 | if (type.isF64()) |
144 | return llvm::APFloat::IEEEdouble(); |
145 | if (type.isF128()) |
146 | return llvm::APFloat::IEEEquad(); |
147 | if (type.isBF16()) |
148 | return llvm::APFloat::BFloat(); |
149 | if (type.isF80()) |
150 | return llvm::APFloat::x87DoubleExtended(); |
151 | llvm_unreachable("unknown float type" ); |
152 | } |
153 | |
154 | /// Returns an attribute with the minimum (if `min` is set) or the maximum value |
155 | /// (otherwise) for the given float type. |
156 | static Attribute minMaxValueForFloat(Type type, bool min) { |
157 | auto fltType = cast<FloatType>(Val&: type); |
158 | return FloatAttr::get( |
159 | type, llvm::APFloat::getLargest(fltSemanticsForType(fltType), min)); |
160 | } |
161 | |
162 | /// Returns an attribute with the signed integer minimum (if `min` is set) or |
163 | /// the maximum value (otherwise) for the given integer type, regardless of its |
164 | /// signedness semantics (only the width is considered). |
165 | static Attribute minMaxValueForSignedInt(Type type, bool min) { |
166 | auto intType = cast<IntegerType>(type); |
167 | unsigned bitwidth = intType.getWidth(); |
168 | return IntegerAttr::get(type, min ? llvm::APInt::getSignedMinValue(bitwidth) |
169 | : llvm::APInt::getSignedMaxValue(bitwidth)); |
170 | } |
171 | |
172 | /// Returns an attribute with the unsigned integer minimum (if `min` is set) or |
173 | /// the maximum value (otherwise) for the given integer type, regardless of its |
174 | /// signedness semantics (only the width is considered). |
175 | static Attribute minMaxValueForUnsignedInt(Type type, bool min) { |
176 | auto intType = cast<IntegerType>(type); |
177 | unsigned bitwidth = intType.getWidth(); |
178 | return IntegerAttr::get(type, min ? llvm::APInt::getZero(bitwidth) |
179 | : llvm::APInt::getAllOnes(bitwidth)); |
180 | } |
181 | |
182 | /// Creates an OpenMP reduction declaration and inserts it into the provided |
183 | /// symbol table. The declaration has a constant initializer with the neutral |
184 | /// value `initValue`, and the `reductionIndex`-th reduction combiner carried |
185 | /// over from `reduce`. |
186 | static omp::DeclareReductionOp |
187 | createDecl(PatternRewriter &builder, SymbolTable &symbolTable, |
188 | scf::ReduceOp reduce, int64_t reductionIndex, Attribute initValue) { |
189 | OpBuilder::InsertionGuard guard(builder); |
190 | Type type = reduce.getOperands()[reductionIndex].getType(); |
191 | auto decl = builder.create<omp::DeclareReductionOp>(reduce.getLoc(), |
192 | "__scf_reduction" , type); |
193 | symbolTable.insert(symbol: decl); |
194 | |
195 | builder.createBlock(&decl.getInitializerRegion(), |
196 | decl.getInitializerRegion().end(), {type}, |
197 | {reduce.getOperands()[reductionIndex].getLoc()}); |
198 | builder.setInsertionPointToEnd(&decl.getInitializerRegion().back()); |
199 | Value init = |
200 | builder.create<LLVM::ConstantOp>(reduce.getLoc(), type, initValue); |
201 | builder.create<omp::YieldOp>(reduce.getLoc(), init); |
202 | |
203 | Operation *terminator = |
204 | &reduce.getReductions()[reductionIndex].front().back(); |
205 | assert(isa<scf::ReduceReturnOp>(terminator) && |
206 | "expected reduce op to be terminated by redure return" ); |
207 | builder.setInsertionPoint(terminator); |
208 | builder.replaceOpWithNewOp<omp::YieldOp>(terminator, |
209 | terminator->getOperands()); |
210 | builder.inlineRegionBefore(reduce.getReductions()[reductionIndex], |
211 | decl.getReductionRegion(), |
212 | decl.getReductionRegion().end()); |
213 | return decl; |
214 | } |
215 | |
216 | /// Adds an atomic reduction combiner to the given OpenMP reduction declaration |
217 | /// using llvm.atomicrmw of the given kind. |
218 | static omp::DeclareReductionOp addAtomicRMW(OpBuilder &builder, |
219 | LLVM::AtomicBinOp atomicKind, |
220 | omp::DeclareReductionOp decl, |
221 | scf::ReduceOp reduce, |
222 | int64_t reductionIndex) { |
223 | OpBuilder::InsertionGuard guard(builder); |
224 | auto ptrType = LLVM::LLVMPointerType::get(builder.getContext()); |
225 | Location reduceOperandLoc = reduce.getOperands()[reductionIndex].getLoc(); |
226 | builder.createBlock(&decl.getAtomicReductionRegion(), |
227 | decl.getAtomicReductionRegion().end(), {ptrType, ptrType}, |
228 | {reduceOperandLoc, reduceOperandLoc}); |
229 | Block *atomicBlock = &decl.getAtomicReductionRegion().back(); |
230 | builder.setInsertionPointToEnd(atomicBlock); |
231 | Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(), |
232 | atomicBlock->getArgument(1)); |
233 | builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind, |
234 | atomicBlock->getArgument(0), loaded, |
235 | LLVM::AtomicOrdering::monotonic); |
236 | builder.create<omp::YieldOp>(reduce.getLoc(), ArrayRef<Value>()); |
237 | return decl; |
238 | } |
239 | |
240 | /// Creates an OpenMP reduction declaration that corresponds to the given SCF |
241 | /// reduction and returns it. Recognizes common reductions in order to identify |
242 | /// the neutral value, necessary for the OpenMP declaration. If the reduction |
243 | /// cannot be recognized, returns null. |
244 | static omp::DeclareReductionOp declareReduction(PatternRewriter &builder, |
245 | scf::ReduceOp reduce, |
246 | int64_t reductionIndex) { |
247 | Operation *container = SymbolTable::getNearestSymbolTable(from: reduce); |
248 | SymbolTable symbolTable(container); |
249 | |
250 | // Insert reduction declarations in the symbol-table ancestor before the |
251 | // ancestor of the current insertion point. |
252 | Operation *insertionPoint = reduce; |
253 | while (insertionPoint->getParentOp() != container) |
254 | insertionPoint = insertionPoint->getParentOp(); |
255 | OpBuilder::InsertionGuard guard(builder); |
256 | builder.setInsertionPoint(insertionPoint); |
257 | |
258 | assert(llvm::hasSingleElement(reduce.getReductions()[reductionIndex]) && |
259 | "expected reduction region to have a single element" ); |
260 | |
261 | // Match simple binary reductions that can be expressed with atomicrmw. |
262 | Type type = reduce.getOperands()[reductionIndex].getType(); |
263 | Block &reduction = reduce.getReductions()[reductionIndex].front(); |
264 | if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) { |
265 | omp::DeclareReductionOp decl = |
266 | createDecl(builder, symbolTable, reduce, reductionIndex, |
267 | builder.getFloatAttr(type, 0.0)); |
268 | return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce, |
269 | reductionIndex); |
270 | } |
271 | if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) { |
272 | omp::DeclareReductionOp decl = |
273 | createDecl(builder, symbolTable, reduce, reductionIndex, |
274 | builder.getIntegerAttr(type, 0)); |
275 | return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce, |
276 | reductionIndex); |
277 | } |
278 | if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) { |
279 | omp::DeclareReductionOp decl = |
280 | createDecl(builder, symbolTable, reduce, reductionIndex, |
281 | builder.getIntegerAttr(type, 0)); |
282 | return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce, |
283 | reductionIndex); |
284 | } |
285 | if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) { |
286 | omp::DeclareReductionOp decl = |
287 | createDecl(builder, symbolTable, reduce, reductionIndex, |
288 | builder.getIntegerAttr(type, 0)); |
289 | return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce, |
290 | reductionIndex); |
291 | } |
292 | if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) { |
293 | omp::DeclareReductionOp decl = createDecl( |
294 | builder, symbolTable, reduce, reductionIndex, |
295 | builder.getIntegerAttr( |
296 | type, llvm::APInt::getAllOnes(type.getIntOrFloatBitWidth()))); |
297 | return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce, |
298 | reductionIndex); |
299 | } |
300 | |
301 | // Match simple binary reductions that cannot be expressed with atomicrmw. |
302 | // TODO: add atomic region using cmpxchg (which needs atomic load to be |
303 | // available as an op). |
304 | if (matchSimpleReduction<arith::MulFOp, LLVM::FMulOp>(reduction)) { |
305 | return createDecl(builder, symbolTable, reduce, reductionIndex, |
306 | builder.getFloatAttr(type, 1.0)); |
307 | } |
308 | if (matchSimpleReduction<arith::MulIOp, LLVM::MulOp>(reduction)) { |
309 | return createDecl(builder, symbolTable, reduce, reductionIndex, |
310 | builder.getIntegerAttr(type, 1)); |
311 | } |
312 | |
313 | // Match select-based min/max reductions. |
314 | bool isMin; |
315 | if (matchSelectReduction<arith::CmpFOp, arith::SelectOp>( |
316 | reduction, {arith::CmpFPredicate::OLT, arith::CmpFPredicate::OLE}, |
317 | {arith::CmpFPredicate::OGT, arith::CmpFPredicate::OGE}, isMin) || |
318 | matchSelectReduction<LLVM::FCmpOp, LLVM::SelectOp>( |
319 | reduction, {LLVM::FCmpPredicate::olt, LLVM::FCmpPredicate::ole}, |
320 | {LLVM::FCmpPredicate::ogt, LLVM::FCmpPredicate::oge}, isMin)) { |
321 | return createDecl(builder, symbolTable, reduce, reductionIndex, |
322 | minMaxValueForFloat(type, !isMin)); |
323 | } |
324 | if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( |
325 | reduction, {arith::CmpIPredicate::slt, arith::CmpIPredicate::sle}, |
326 | {arith::CmpIPredicate::sgt, arith::CmpIPredicate::sge}, isMin) || |
327 | matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( |
328 | reduction, {LLVM::ICmpPredicate::slt, LLVM::ICmpPredicate::sle}, |
329 | {LLVM::ICmpPredicate::sgt, LLVM::ICmpPredicate::sge}, isMin)) { |
330 | omp::DeclareReductionOp decl = |
331 | createDecl(builder, symbolTable, reduce, reductionIndex, |
332 | minMaxValueForSignedInt(type, !isMin)); |
333 | return addAtomicRMW(builder, |
334 | isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max, |
335 | decl, reduce, reductionIndex); |
336 | } |
337 | if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>( |
338 | reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule}, |
339 | {arith::CmpIPredicate::ugt, arith::CmpIPredicate::uge}, isMin) || |
340 | matchSelectReduction<LLVM::ICmpOp, LLVM::SelectOp>( |
341 | reduction, {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::ule}, |
342 | {LLVM::ICmpPredicate::ugt, LLVM::ICmpPredicate::uge}, isMin)) { |
343 | omp::DeclareReductionOp decl = |
344 | createDecl(builder, symbolTable, reduce, reductionIndex, |
345 | minMaxValueForUnsignedInt(type, !isMin)); |
346 | return addAtomicRMW( |
347 | builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax, |
348 | decl, reduce, reductionIndex); |
349 | } |
350 | |
351 | return nullptr; |
352 | } |
353 | |
354 | namespace { |
355 | |
356 | struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { |
357 | static constexpr unsigned kUseOpenMPDefaultNumThreads = 0; |
358 | unsigned numThreads; |
359 | |
360 | ParallelOpLowering(MLIRContext *context, |
361 | unsigned numThreads = kUseOpenMPDefaultNumThreads) |
362 | : OpRewritePattern<scf::ParallelOp>(context), numThreads(numThreads) {} |
363 | |
364 | LogicalResult matchAndRewrite(scf::ParallelOp parallelOp, |
365 | PatternRewriter &rewriter) const override { |
366 | // Declare reductions. |
367 | // TODO: consider checking it here is already a compatible reduction |
368 | // declaration and use it instead of redeclaring. |
369 | SmallVector<Attribute> reductionDeclSymbols; |
370 | SmallVector<omp::DeclareReductionOp> ompReductionDecls; |
371 | auto reduce = cast<scf::ReduceOp>(parallelOp.getBody()->getTerminator()); |
372 | for (int64_t i = 0, e = parallelOp.getNumReductions(); i < e; ++i) { |
373 | omp::DeclareReductionOp decl = declareReduction(rewriter, reduce, i); |
374 | ompReductionDecls.push_back(decl); |
375 | if (!decl) |
376 | return failure(); |
377 | reductionDeclSymbols.push_back( |
378 | SymbolRefAttr::get(rewriter.getContext(), decl.getSymName())); |
379 | } |
380 | |
381 | // Allocate reduction variables. Make sure the we don't overflow the stack |
382 | // with local `alloca`s by saving and restoring the stack pointer. |
383 | Location loc = parallelOp.getLoc(); |
384 | Value one = rewriter.create<LLVM::ConstantOp>( |
385 | loc, rewriter.getIntegerType(64), rewriter.getI64IntegerAttr(1)); |
386 | SmallVector<Value> reductionVariables; |
387 | reductionVariables.reserve(N: parallelOp.getNumReductions()); |
388 | auto ptrType = LLVM::LLVMPointerType::get(parallelOp.getContext()); |
389 | for (Value init : parallelOp.getInitVals()) { |
390 | assert((LLVM::isCompatibleType(init.getType()) || |
391 | isa<LLVM::PointerElementTypeInterface>(init.getType())) && |
392 | "cannot create a reduction variable if the type is not an LLVM " |
393 | "pointer element" ); |
394 | Value storage = |
395 | rewriter.create<LLVM::AllocaOp>(loc, ptrType, init.getType(), one, 0); |
396 | rewriter.create<LLVM::StoreOp>(loc, init, storage); |
397 | reductionVariables.push_back(storage); |
398 | } |
399 | |
400 | // Replace the reduction operations contained in this loop. Must be done |
401 | // here rather than in a separate pattern to have access to the list of |
402 | // reduction variables. |
403 | for (auto [x, y, rD] : llvm::zip_equal( |
404 | reductionVariables, reduce.getOperands(), ompReductionDecls)) { |
405 | OpBuilder::InsertionGuard guard(rewriter); |
406 | rewriter.setInsertionPoint(reduce); |
407 | Region &redRegion = rD.getReductionRegion(); |
408 | // The SCF dialect by definition contains only structured operations |
409 | // and hence the SCF reduction region will contain a single block. |
410 | // The ompReductionDecls region is a copy of the SCF reduction region |
411 | // and hence has the same property. |
412 | assert(redRegion.hasOneBlock() && |
413 | "expect reduction region to have one block" ); |
414 | Value pvtRedVar = parallelOp.getRegion().addArgument(x.getType(), loc); |
415 | Value pvtRedVal = rewriter.create<LLVM::LoadOp>(reduce.getLoc(), |
416 | rD.getType(), pvtRedVar); |
417 | // Make a copy of the reduction combiner region in the body |
418 | mlir::OpBuilder builder(rewriter.getContext()); |
419 | builder.setInsertionPoint(reduce); |
420 | mlir::IRMapping mapper; |
421 | assert(redRegion.getNumArguments() == 2 && |
422 | "expect reduction region to have two arguments" ); |
423 | mapper.map(redRegion.getArgument(0), pvtRedVal); |
424 | mapper.map(redRegion.getArgument(1), y); |
425 | for (auto &op : redRegion.getOps()) { |
426 | Operation *cloneOp = builder.clone(op, mapper); |
427 | if (auto yieldOp = dyn_cast<omp::YieldOp>(*cloneOp)) { |
428 | assert(yieldOp && yieldOp.getResults().size() == 1 && |
429 | "expect YieldOp in reduction region to return one result" ); |
430 | Value redVal = yieldOp.getResults()[0]; |
431 | rewriter.create<LLVM::StoreOp>(loc, redVal, pvtRedVar); |
432 | rewriter.eraseOp(yieldOp); |
433 | break; |
434 | } |
435 | } |
436 | } |
437 | rewriter.eraseOp(op: reduce); |
438 | |
439 | Value numThreadsVar; |
440 | if (numThreads > 0) { |
441 | numThreadsVar = rewriter.create<LLVM::ConstantOp>( |
442 | loc, rewriter.getI32IntegerAttr(numThreads)); |
443 | } |
444 | // Create the parallel wrapper. |
445 | auto ompParallel = rewriter.create<omp::ParallelOp>( |
446 | loc, |
447 | /* if_expr_var = */ Value{}, |
448 | /* num_threads_var = */ numThreadsVar, |
449 | /* allocate_vars = */ llvm::SmallVector<Value>{}, |
450 | /* allocators_vars = */ llvm::SmallVector<Value>{}, |
451 | /* reduction_vars = */ llvm::SmallVector<Value>{}, |
452 | /* reductions = */ ArrayAttr{}, |
453 | /* proc_bind_val = */ omp::ClauseProcBindKindAttr{}, |
454 | /* private_vars = */ ValueRange(), |
455 | /* privatizers = */ nullptr); |
456 | { |
457 | |
458 | OpBuilder::InsertionGuard guard(rewriter); |
459 | rewriter.createBlock(&ompParallel.getRegion()); |
460 | |
461 | // Replace the loop. |
462 | { |
463 | OpBuilder::InsertionGuard allocaGuard(rewriter); |
464 | // Create worksharing loop wrapper. |
465 | auto wsloopOp = rewriter.create<omp::WsloopOp>(parallelOp.getLoc()); |
466 | if (!reductionVariables.empty()) { |
467 | wsloopOp.setReductionsAttr( |
468 | ArrayAttr::get(rewriter.getContext(), reductionDeclSymbols)); |
469 | wsloopOp.getReductionVarsMutable().append(reductionVariables); |
470 | } |
471 | rewriter.create<omp::TerminatorOp>(loc); // omp.parallel terminator. |
472 | |
473 | // The wrapper's entry block arguments will define the reduction |
474 | // variables. |
475 | llvm::SmallVector<mlir::Type> reductionTypes; |
476 | reductionTypes.reserve(N: reductionVariables.size()); |
477 | llvm::transform(Range&: reductionVariables, d_first: std::back_inserter(x&: reductionTypes), |
478 | F: [](mlir::Value v) { return v.getType(); }); |
479 | rewriter.createBlock( |
480 | &wsloopOp.getRegion(), {}, reductionTypes, |
481 | llvm::SmallVector<mlir::Location>(reductionVariables.size(), |
482 | parallelOp.getLoc())); |
483 | |
484 | rewriter.setInsertionPoint( |
485 | rewriter.create<omp::TerminatorOp>(parallelOp.getLoc())); |
486 | |
487 | // Create loop nest and populate region with contents of scf.parallel. |
488 | auto loopOp = rewriter.create<omp::LoopNestOp>( |
489 | parallelOp.getLoc(), parallelOp.getLowerBound(), |
490 | parallelOp.getUpperBound(), parallelOp.getStep()); |
491 | |
492 | rewriter.inlineRegionBefore(parallelOp.getRegion(), loopOp.getRegion(), |
493 | loopOp.getRegion().begin()); |
494 | |
495 | // Remove reduction-related block arguments from omp.loop_nest and |
496 | // redirect uses to the corresponding omp.wsloop block argument. |
497 | mlir::Block &loopOpEntryBlock = loopOp.getRegion().front(); |
498 | unsigned numLoops = parallelOp.getNumLoops(); |
499 | rewriter.replaceAllUsesWith( |
500 | loopOpEntryBlock.getArguments().drop_front(N: numLoops), |
501 | wsloopOp.getRegion().getArguments()); |
502 | loopOpEntryBlock.eraseArguments( |
503 | start: numLoops, num: loopOpEntryBlock.getNumArguments() - numLoops); |
504 | |
505 | Block *ops = |
506 | rewriter.splitBlock(block: &loopOpEntryBlock, before: loopOpEntryBlock.begin()); |
507 | rewriter.setInsertionPointToStart(&loopOpEntryBlock); |
508 | |
509 | auto scope = rewriter.create<memref::AllocaScopeOp>(parallelOp.getLoc(), |
510 | TypeRange()); |
511 | rewriter.create<omp::YieldOp>(loc, ValueRange()); |
512 | Block *scopeBlock = rewriter.createBlock(&scope.getBodyRegion()); |
513 | rewriter.mergeBlocks(source: ops, dest: scopeBlock); |
514 | rewriter.setInsertionPointToEnd(&*scope.getBodyRegion().begin()); |
515 | rewriter.create<memref::AllocaScopeReturnOp>(loc, ValueRange()); |
516 | } |
517 | } |
518 | |
519 | // Load loop results. |
520 | SmallVector<Value> results; |
521 | results.reserve(N: reductionVariables.size()); |
522 | for (auto [variable, type] : |
523 | llvm::zip(reductionVariables, parallelOp.getResultTypes())) { |
524 | Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable); |
525 | results.push_back(res); |
526 | } |
527 | rewriter.replaceOp(parallelOp, results); |
528 | |
529 | return success(); |
530 | } |
531 | }; |
532 | |
533 | /// Applies the conversion patterns in the given function. |
534 | static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { |
535 | ConversionTarget target(*module.getContext()); |
536 | target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); |
537 | target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, |
538 | memref::MemRefDialect>(); |
539 | |
540 | RewritePatternSet patterns(module.getContext()); |
541 | patterns.add<ParallelOpLowering>(module.getContext(), numThreads); |
542 | FrozenRewritePatternSet frozen(std::move(patterns)); |
543 | return applyPartialConversion(module, target, frozen); |
544 | } |
545 | |
546 | /// A pass converting SCF operations to OpenMP operations. |
547 | struct SCFToOpenMPPass |
548 | : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> { |
549 | |
550 | using Base::Base; |
551 | |
552 | /// Pass entry point. |
553 | void runOnOperation() override { |
554 | if (failed(applyPatterns(getOperation(), numThreads))) |
555 | signalPassFailure(); |
556 | } |
557 | }; |
558 | |
559 | } // namespace |
560 | |