1//===-------- TestLoopUnrolling.cpp --- loop unrolling test pass ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to unroll loops by a specified unroll factor.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/SCF/IR/SCF.h"
15#include "mlir/Dialect/SCF/Utils/Utils.h"
16#include "mlir/IR/Builders.h"
17#include "mlir/Pass/Pass.h"
18
19using namespace mlir;
20
21namespace {
22
23static unsigned getNestingDepth(Operation *op) {
24 Operation *currOp = op;
25 unsigned depth = 0;
26 while ((currOp = currOp->getParentOp())) {
27 if (isa<scf::ForOp>(currOp))
28 depth++;
29 }
30 return depth;
31}
32
33struct TestLoopUnrollingPass
34 : public PassWrapper<TestLoopUnrollingPass, OperationPass<>> {
35 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopUnrollingPass)
36
37 StringRef getArgument() const final { return "test-loop-unrolling"; }
38 StringRef getDescription() const final {
39 return "Tests loop unrolling transformation";
40 }
41 TestLoopUnrollingPass() = default;
42 TestLoopUnrollingPass(const TestLoopUnrollingPass &) {}
43 explicit TestLoopUnrollingPass(uint64_t unrollFactorParam,
44 unsigned loopDepthParam,
45 bool annotateLoopParam, bool unrollFullParam) {
46 unrollFactor = unrollFactorParam;
47 loopDepth = loopDepthParam;
48 annotateLoop = annotateLoopParam;
49 unrollFull = unrollFactorParam;
50 }
51
52 void getDependentDialects(DialectRegistry &registry) const override {
53 registry.insert<arith::ArithDialect>();
54 }
55
56 void runOnOperation() override {
57 SmallVector<scf::ForOp, 4> loops;
58 getOperation()->walk(callback: [&](scf::ForOp forOp) {
59 if (getNestingDepth(forOp) == loopDepth)
60 loops.push_back(forOp);
61 });
62 auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
63 if (annotateLoop) {
64 op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
65 }
66 };
67 for (auto loop : loops) {
68 if (unrollFull)
69 (void)loopUnrollFull(loop);
70 else
71 (void)loopUnrollByFactor(loop, unrollFactor, annotateFn);
72 }
73 }
74 Option<uint64_t> unrollFactor{*this, "unroll-factor",
75 llvm::cl::desc("Loop unroll factor."),
76 llvm::cl::init(Val: 1)};
77 Option<bool> annotateLoop{*this, "annotate",
78 llvm::cl::desc("Annotate unrolled iterations."),
79 llvm::cl::init(Val: false)};
80 Option<bool> unrollUpToFactor{*this, "unroll-up-to-factor",
81 llvm::cl::desc("Loop unroll up to factor."),
82 llvm::cl::init(Val: false)};
83 Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
84 llvm::cl::init(Val: 0)};
85 Option<bool> unrollFull{*this, "unroll-full",
86 llvm::cl::desc("Full unroll loops."),
87 llvm::cl::init(Val: false)};
88};
89} // namespace
90
91namespace mlir {
92namespace test {
93void registerTestLoopUnrollingPass() {
94 PassRegistration<TestLoopUnrollingPass>();
95}
96} // namespace test
97} // namespace mlir
98

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/test/lib/Dialect/SCF/TestLoopUnrolling.cpp