1 | //===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a pass to test SCF dialect utils. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
16 | #include "mlir/Dialect/SCF/IR/SCF.h" |
17 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
18 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
19 | #include "mlir/IR/Builders.h" |
20 | #include "mlir/IR/PatternMatch.h" |
21 | #include "mlir/Pass/Pass.h" |
22 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
23 | |
24 | using namespace mlir; |
25 | |
26 | namespace { |
27 | struct TestSCFForUtilsPass |
28 | : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> { |
29 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass) |
30 | |
31 | StringRef getArgument() const final { return "test-scf-for-utils"; } |
32 | StringRef getDescription() const final { return "test scf.for utils"; } |
33 | explicit TestSCFForUtilsPass() = default; |
34 | TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {} |
35 | |
36 | Option<bool> testReplaceWithNewYields{ |
37 | *this, "test-replace-with-new-yields", |
38 | llvm::cl::desc("Test replacing a loop with a new loop that returns new " |
39 | "additional yield values"), |
40 | llvm::cl::init(Val: false)}; |
41 | |
42 | void runOnOperation() override { |
43 | func::FuncOp func = getOperation(); |
44 | |
45 | if (testReplaceWithNewYields) { |
46 | func.walk([&](scf::ForOp forOp) { |
47 | if (forOp.getNumResults() == 0) |
48 | return; |
49 | auto newInitValues = forOp.getInitArgs(); |
50 | if (newInitValues.empty()) |
51 | return; |
52 | SmallVector<Value> oldYieldValues = |
53 | llvm::to_vector(forOp.getYieldedValues()); |
54 | NewYieldValuesFn fn = [&](OpBuilder &b, Location loc, |
55 | ArrayRef<BlockArgument> newBBArgs) { |
56 | SmallVector<Value> newYieldValues; |
57 | for (auto yieldVal : oldYieldValues) { |
58 | newYieldValues.push_back( |
59 | b.create<arith::AddFOp>(loc, yieldVal, yieldVal)); |
60 | } |
61 | return newYieldValues; |
62 | }; |
63 | IRRewriter rewriter(forOp.getContext()); |
64 | if (failed(forOp.replaceWithAdditionalYields( |
65 | rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true, |
66 | fn))) |
67 | signalPassFailure(); |
68 | }); |
69 | } |
70 | } |
71 | }; |
72 | |
73 | struct TestSCFIfUtilsPass |
74 | : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> { |
75 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass) |
76 | |
77 | StringRef getArgument() const final { return "test-scf-if-utils"; } |
78 | StringRef getDescription() const final { return "test scf.if utils"; } |
79 | explicit TestSCFIfUtilsPass() = default; |
80 | |
81 | void getDependentDialects(DialectRegistry ®istry) const override { |
82 | registry.insert<func::FuncDialect>(); |
83 | } |
84 | |
85 | void runOnOperation() override { |
86 | int count = 0; |
87 | getOperation().walk([&](scf::IfOp ifOp) { |
88 | auto strCount = std::to_string(val: count++); |
89 | func::FuncOp thenFn, elseFn; |
90 | OpBuilder b(ifOp); |
91 | IRRewriter rewriter(b); |
92 | if (failed(outlineIfOp(rewriter, ifOp, &thenFn, |
93 | std::string("outlined_then") + strCount, &elseFn, |
94 | std::string("outlined_else") + strCount))) { |
95 | this->signalPassFailure(); |
96 | return WalkResult::interrupt(); |
97 | } |
98 | return WalkResult::advance(); |
99 | }); |
100 | } |
101 | }; |
102 | |
103 | static const StringLiteral kTestPipeliningLoopMarker = |
104 | "__test_pipelining_loop__"; |
105 | static const StringLiteral kTestPipeliningStageMarker = |
106 | "__test_pipelining_stage__"; |
107 | /// Marker to express the order in which operations should be after |
108 | /// pipelining. |
109 | static const StringLiteral kTestPipeliningOpOrderMarker = |
110 | "__test_pipelining_op_order__"; |
111 | |
112 | static const StringLiteral kTestPipeliningAnnotationPart = |
113 | "__test_pipelining_part"; |
114 | static const StringLiteral kTestPipeliningAnnotationIteration = |
115 | "__test_pipelining_iteration"; |
116 | |
117 | struct TestSCFPipeliningPass |
118 | : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> { |
119 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass) |
120 | |
121 | TestSCFPipeliningPass() = default; |
122 | TestSCFPipeliningPass(const TestSCFPipeliningPass &) {} |
123 | StringRef getArgument() const final { return "test-scf-pipelining"; } |
124 | StringRef getDescription() const final { return "test scf.forOp pipelining"; } |
125 | |
126 | Option<bool> annotatePipeline{ |
127 | *this, "annotate", |
128 | llvm::cl::desc("Annote operations during loop pipelining transformation"), |
129 | llvm::cl::init(Val: false)}; |
130 | |
131 | Option<bool> noEpiloguePeeling{ |
132 | *this, "no-epilogue-peeling", |
133 | llvm::cl::desc("Use predicates instead of peeling the epilogue."), |
134 | llvm::cl::init(Val: false)}; |
135 | |
136 | static void |
137 | getSchedule(scf::ForOp forOp, |
138 | std::vector<std::pair<Operation *, unsigned>> &schedule) { |
139 | if (!forOp->hasAttr(kTestPipeliningLoopMarker)) |
140 | return; |
141 | |
142 | schedule.resize(forOp.getBody()->getOperations().size() - 1); |
143 | forOp.walk([&schedule](Operation *op) { |
144 | auto attrStage = |
145 | op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker); |
146 | auto attrCycle = |
147 | op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker); |
148 | if (attrCycle && attrStage) { |
149 | // TODO: Index can be out-of-bounds if ops of the loop body disappear |
150 | // due to folding. |
151 | schedule[attrCycle.getInt()] = |
152 | std::make_pair(x&: op, y: unsigned(attrStage.getInt())); |
153 | } |
154 | }); |
155 | } |
156 | |
157 | /// Helper to generate "predicated" version of `op`. For simplicity we just |
158 | /// wrap the operation in a scf.ifOp operation. |
159 | static Operation *predicateOp(RewriterBase &rewriter, Operation *op, |
160 | Value pred) { |
161 | Location loc = op->getLoc(); |
162 | auto ifOp = |
163 | rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true); |
164 | // True branch. |
165 | rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(), |
166 | ifOp.getThenRegion().front().begin()); |
167 | rewriter.setInsertionPointAfter(op); |
168 | if (op->getNumResults() > 0) |
169 | rewriter.create<scf::YieldOp>(loc, op->getResults()); |
170 | // False branch. |
171 | rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front()); |
172 | SmallVector<Value> elseYieldOperands; |
173 | elseYieldOperands.reserve(N: ifOp.getNumResults()); |
174 | if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) { |
175 | // For sub-views, just clone the op. |
176 | // NOTE: This is okay in the test because we use dynamic memref sizes, so |
177 | // the verifier will not complain. Otherwise, we may create a logically |
178 | // out-of-bounds view and a different technique should be used. |
179 | Operation *opClone = rewriter.clone(op&: *op); |
180 | elseYieldOperands.append(in_start: opClone->result_begin(), in_end: opClone->result_end()); |
181 | } else { |
182 | // Default to assuming constant numeric values. |
183 | for (Type type : op->getResultTypes()) { |
184 | elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>( |
185 | loc, rewriter.getZeroAttr(type))); |
186 | } |
187 | } |
188 | if (op->getNumResults() > 0) |
189 | rewriter.create<scf::YieldOp>(loc, elseYieldOperands); |
190 | return ifOp.getOperation(); |
191 | } |
192 | |
193 | static void annotate(Operation *op, |
194 | mlir::scf::PipeliningOption::PipelinerPart part, |
195 | unsigned iteration) { |
196 | OpBuilder b(op); |
197 | switch (part) { |
198 | case mlir::scf::PipeliningOption::PipelinerPart::Prologue: |
199 | op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue")); |
200 | break; |
201 | case mlir::scf::PipeliningOption::PipelinerPart::Kernel: |
202 | op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel")); |
203 | break; |
204 | case mlir::scf::PipeliningOption::PipelinerPart::Epilogue: |
205 | op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue")); |
206 | break; |
207 | } |
208 | op->setAttr(kTestPipeliningAnnotationIteration, |
209 | b.getI32IntegerAttr(iteration)); |
210 | } |
211 | |
212 | void getDependentDialects(DialectRegistry ®istry) const override { |
213 | registry.insert<arith::ArithDialect, memref::MemRefDialect>(); |
214 | } |
215 | |
216 | void runOnOperation() override { |
217 | RewritePatternSet patterns(&getContext()); |
218 | mlir::scf::PipeliningOption options; |
219 | options.getScheduleFn = getSchedule; |
220 | options.supportDynamicLoops = true; |
221 | options.predicateFn = predicateOp; |
222 | if (annotatePipeline) |
223 | options.annotateFn = annotate; |
224 | if (noEpiloguePeeling) { |
225 | options.peelEpilogue = false; |
226 | } |
227 | scf::populateSCFLoopPipeliningPatterns(patterns, options); |
228 | (void)applyPatternsGreedily(getOperation(), std::move(patterns)); |
229 | getOperation().walk([](Operation *op) { |
230 | // Clean up the markers. |
231 | op->removeAttr(name: kTestPipeliningStageMarker); |
232 | op->removeAttr(name: kTestPipeliningOpOrderMarker); |
233 | }); |
234 | } |
235 | }; |
236 | } // namespace |
237 | |
238 | namespace mlir { |
239 | namespace test { |
240 | void registerTestSCFUtilsPass() { |
241 | PassRegistration<TestSCFForUtilsPass>(); |
242 | PassRegistration<TestSCFIfUtilsPass>(); |
243 | PassRegistration<TestSCFPipeliningPass>(); |
244 | } |
245 | } // namespace test |
246 | } // namespace mlir |
247 |
Definitions
- TestSCFForUtilsPass
- getArgument
- getDescription
- TestSCFForUtilsPass
- TestSCFForUtilsPass
- runOnOperation
- TestSCFIfUtilsPass
- getArgument
- getDescription
- TestSCFIfUtilsPass
- getDependentDialects
- runOnOperation
- kTestPipeliningLoopMarker
- kTestPipeliningStageMarker
- kTestPipeliningOpOrderMarker
- kTestPipeliningAnnotationPart
- kTestPipeliningAnnotationIteration
- TestSCFPipeliningPass
- TestSCFPipeliningPass
- TestSCFPipeliningPass
- getArgument
- getDescription
- getSchedule
- predicateOp
- annotate
- getDependentDialects
- runOnOperation
Improve your Profiling and Debugging skills
Find out more