1//===- ForallToFor.cpp - scf.forall to scf.for loop conversion ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Transforms scf.forall ops into nests of scf.for ops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Passes.h"
14
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/SCF/Transforms/Transforms.h"
17#include "mlir/IR/PatternMatch.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_SCFFORALLTOFORLOOP
21#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22} // namespace mlir
23
24using namespace llvm;
25using namespace mlir;
26using scf::LoopNest;
27
28LogicalResult
29mlir::scf::forallToForLoop(RewriterBase &rewriter, scf::ForallOp forallOp,
30 SmallVectorImpl<Operation *> *results) {
31 OpBuilder::InsertionGuard guard(rewriter);
32 rewriter.setInsertionPoint(forallOp);
33
34 Location loc = forallOp.getLoc();
35 SmallVector<Value> lbs = forallOp.getLowerBound(rewriter);
36 SmallVector<Value> ubs = forallOp.getUpperBound(rewriter);
37 SmallVector<Value> steps = forallOp.getStep(rewriter);
38 LoopNest loopNest = scf::buildLoopNest(rewriter, loc, lbs, ubs, steps);
39
40 SmallVector<Value> ivs = llvm::map_to_vector(
41 loopNest.loops, [](scf::ForOp loop) { return loop.getInductionVar(); });
42
43 Block *innermostBlock = loopNest.loops.back().getBody();
44 rewriter.eraseOp(op: forallOp.getBody()->getTerminator());
45 rewriter.inlineBlockBefore(forallOp.getBody(), innermostBlock,
46 innermostBlock->getTerminator()->getIterator(),
47 ivs);
48 rewriter.eraseOp(op: forallOp);
49
50 if (results) {
51 llvm::move(loopNest.loops, std::back_inserter(*results));
52 }
53
54 return success();
55}
56
57namespace {
58struct ForallToForLoop : public impl::SCFForallToForLoopBase<ForallToForLoop> {
59 void runOnOperation() override {
60 Operation *parentOp = getOperation();
61 IRRewriter rewriter(parentOp->getContext());
62
63 parentOp->walk([&](scf::ForallOp forallOp) {
64 if (failed(scf::forallToForLoop(rewriter, forallOp: forallOp))) {
65 return signalPassFailure();
66 }
67 });
68 }
69};
70} // namespace
71
72std::unique_ptr<Pass> mlir::createForallToForLoopPass() {
73 return std::make_unique<ForallToForLoop>();
74}
75

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp