1 | //===- ForallToParallel.cpp - scf.forall to scf.parallel loop conversion --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// Transforms scf.forall ops into scf.parallel ops.
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/SCF/IR/SCF.h" |
14 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
15 | #include "mlir/Dialect/SCF/Transforms/Transforms.h" |
16 | #include "mlir/IR/PatternMatch.h" |
17 | |
18 | namespace mlir { |
19 | #define GEN_PASS_DEF_SCFFORALLTOPARALLELLOOP |
20 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
21 | } // namespace mlir |
22 | |
23 | using namespace mlir; |
24 | |
25 | LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, |
26 | scf::ForallOp forallOp, |
27 | scf::ParallelOp *result) { |
28 | OpBuilder::InsertionGuard guard(rewriter); |
29 | rewriter.setInsertionPoint(forallOp); |
30 | |
31 | Location loc = forallOp.getLoc(); |
32 | if (!forallOp.getOutputs().empty()) |
33 | return rewriter.notifyMatchFailure( |
34 | forallOp, |
35 | "only fully bufferized scf.forall ops can be lowered to scf.parallel" ); |
36 | |
37 | // Convert mixed bounds and steps to SSA values. |
38 | SmallVector<Value> lbs = forallOp.getLowerBound(rewriter); |
39 | SmallVector<Value> ubs = forallOp.getUpperBound(rewriter); |
40 | SmallVector<Value> steps = forallOp.getStep(rewriter); |
41 | |
42 | // Create empty scf.parallel op. |
43 | auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps); |
44 | rewriter.eraseBlock(block: ¶llelOp.getRegion().front()); |
45 | rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), |
46 | parallelOp.getRegion().begin()); |
47 | // Replace the terminator. |
48 | rewriter.setInsertionPointToEnd(¶llelOp.getRegion().front()); |
49 | rewriter.replaceOpWithNewOp<scf::ReduceOp>( |
50 | parallelOp.getRegion().front().getTerminator()); |
51 | |
52 | // If the mapping attribute is present, propagate to the new parallelOp. |
53 | if (forallOp.getMapping()) |
54 | parallelOp->setAttr("mapping" , *forallOp.getMapping()); |
55 | |
56 | // Erase the scf.forall op. |
57 | rewriter.replaceOp(forallOp, parallelOp); |
58 | |
59 | if (result) |
60 | *result = parallelOp; |
61 | |
62 | return success(); |
63 | } |
64 | |
65 | namespace { |
66 | struct ForallToParallelLoop final |
67 | : public impl::SCFForallToParallelLoopBase<ForallToParallelLoop> { |
68 | void runOnOperation() override { |
69 | Operation *parentOp = getOperation(); |
70 | IRRewriter rewriter(parentOp->getContext()); |
71 | |
72 | parentOp->walk([&](scf::ForallOp forallOp) { |
73 | if (failed(scf::forallToParallelLoop(rewriter, forallOp: forallOp))) { |
74 | return signalPassFailure(); |
75 | } |
76 | }); |
77 | } |
78 | }; |
79 | } // namespace |
80 | |
81 | std::unique_ptr<Pass> mlir::createForallToParallelLoopPass() { |
82 | return std::make_unique<ForallToParallelLoop>(); |
83 | } |
84 | |