//===- TestSCFUtils.cpp --- Pass to test independent SCF dialect utils ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to test SCF dialect utils: scf.for yield
// extension, scf.if outlining, and scf.for loop pipelining.
//
//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/MemRef/IR/MemRef.h"
16#include "mlir/Dialect/SCF/IR/SCF.h"
17#include "mlir/Dialect/SCF/Transforms/Patterns.h"
18#include "mlir/Dialect/SCF/Utils/Utils.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/Pass/Pass.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23
24using namespace mlir;
25
26namespace {
27struct TestSCFForUtilsPass
28 : public PassWrapper<TestSCFForUtilsPass, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFForUtilsPass)
30
31 StringRef getArgument() const final { return "test-scf-for-utils"; }
32 StringRef getDescription() const final { return "test scf.for utils"; }
33 explicit TestSCFForUtilsPass() = default;
34 TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
35
36 Option<bool> testReplaceWithNewYields{
37 *this, "test-replace-with-new-yields",
38 llvm::cl::desc("Test replacing a loop with a new loop that returns new "
39 "additional yield values"),
40 llvm::cl::init(Val: false)};
41
42 void runOnOperation() override {
43 func::FuncOp func = getOperation();
44
45 if (testReplaceWithNewYields) {
46 func.walk([&](scf::ForOp forOp) {
47 if (forOp.getNumResults() == 0)
48 return;
49 auto newInitValues = forOp.getInitArgs();
50 if (newInitValues.empty())
51 return;
52 SmallVector<Value> oldYieldValues =
53 llvm::to_vector(forOp.getYieldedValues());
54 NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
55 ArrayRef<BlockArgument> newBBArgs) {
56 SmallVector<Value> newYieldValues;
57 for (auto yieldVal : oldYieldValues) {
58 newYieldValues.push_back(
59 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
60 }
61 return newYieldValues;
62 };
63 IRRewriter rewriter(forOp.getContext());
64 if (failed(forOp.replaceWithAdditionalYields(
65 rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
66 fn)))
67 signalPassFailure();
68 });
69 }
70 }
71};
72
73struct TestSCFIfUtilsPass
74 : public PassWrapper<TestSCFIfUtilsPass, OperationPass<ModuleOp>> {
75 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFIfUtilsPass)
76
77 StringRef getArgument() const final { return "test-scf-if-utils"; }
78 StringRef getDescription() const final { return "test scf.if utils"; }
79 explicit TestSCFIfUtilsPass() = default;
80
81 void getDependentDialects(DialectRegistry &registry) const override {
82 registry.insert<func::FuncDialect>();
83 }
84
85 void runOnOperation() override {
86 int count = 0;
87 getOperation().walk([&](scf::IfOp ifOp) {
88 auto strCount = std::to_string(val: count++);
89 func::FuncOp thenFn, elseFn;
90 OpBuilder b(ifOp);
91 IRRewriter rewriter(b);
92 if (failed(outlineIfOp(rewriter, ifOp, &thenFn,
93 std::string("outlined_then") + strCount, &elseFn,
94 std::string("outlined_else") + strCount))) {
95 this->signalPassFailure();
96 return WalkResult::interrupt();
97 }
98 return WalkResult::advance();
99 });
100 }
101};
102
103static const StringLiteral kTestPipeliningLoopMarker =
104 "__test_pipelining_loop__";
105static const StringLiteral kTestPipeliningStageMarker =
106 "__test_pipelining_stage__";
107/// Marker to express the order in which operations should be after
108/// pipelining.
109static const StringLiteral kTestPipeliningOpOrderMarker =
110 "__test_pipelining_op_order__";
111
112static const StringLiteral kTestPipeliningAnnotationPart =
113 "__test_pipelining_part";
114static const StringLiteral kTestPipeliningAnnotationIteration =
115 "__test_pipelining_iteration";
116
117struct TestSCFPipeliningPass
118 : public PassWrapper<TestSCFPipeliningPass, OperationPass<func::FuncOp>> {
119 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSCFPipeliningPass)
120
121 TestSCFPipeliningPass() = default;
122 TestSCFPipeliningPass(const TestSCFPipeliningPass &) {}
123 StringRef getArgument() const final { return "test-scf-pipelining"; }
124 StringRef getDescription() const final { return "test scf.forOp pipelining"; }
125
126 Option<bool> annotatePipeline{
127 *this, "annotate",
128 llvm::cl::desc("Annote operations during loop pipelining transformation"),
129 llvm::cl::init(Val: false)};
130
131 Option<bool> noEpiloguePeeling{
132 *this, "no-epilogue-peeling",
133 llvm::cl::desc("Use predicates instead of peeling the epilogue."),
134 llvm::cl::init(Val: false)};
135
136 static void
137 getSchedule(scf::ForOp forOp,
138 std::vector<std::pair<Operation *, unsigned>> &schedule) {
139 if (!forOp->hasAttr(kTestPipeliningLoopMarker))
140 return;
141
142 schedule.resize(forOp.getBody()->getOperations().size() - 1);
143 forOp.walk([&schedule](Operation *op) {
144 auto attrStage =
145 op->getAttrOfType<IntegerAttr>(kTestPipeliningStageMarker);
146 auto attrCycle =
147 op->getAttrOfType<IntegerAttr>(kTestPipeliningOpOrderMarker);
148 if (attrCycle && attrStage) {
149 // TODO: Index can be out-of-bounds if ops of the loop body disappear
150 // due to folding.
151 schedule[attrCycle.getInt()] =
152 std::make_pair(x&: op, y: unsigned(attrStage.getInt()));
153 }
154 });
155 }
156
157 /// Helper to generate "predicated" version of `op`. For simplicity we just
158 /// wrap the operation in a scf.ifOp operation.
159 static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
160 Value pred) {
161 Location loc = op->getLoc();
162 auto ifOp =
163 rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
164 // True branch.
165 rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
166 ifOp.getThenRegion().front().begin());
167 rewriter.setInsertionPointAfter(op);
168 if (op->getNumResults() > 0)
169 rewriter.create<scf::YieldOp>(loc, op->getResults());
170 // False branch.
171 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
172 SmallVector<Value> elseYieldOperands;
173 elseYieldOperands.reserve(N: ifOp.getNumResults());
174 if (auto viewOp = dyn_cast<memref::SubViewOp>(op)) {
175 // For sub-views, just clone the op.
176 // NOTE: This is okay in the test because we use dynamic memref sizes, so
177 // the verifier will not complain. Otherwise, we may create a logically
178 // out-of-bounds view and a different technique should be used.
179 Operation *opClone = rewriter.clone(op&: *op);
180 elseYieldOperands.append(in_start: opClone->result_begin(), in_end: opClone->result_end());
181 } else {
182 // Default to assuming constant numeric values.
183 for (Type type : op->getResultTypes()) {
184 elseYieldOperands.push_back(rewriter.create<arith::ConstantOp>(
185 loc, rewriter.getZeroAttr(type)));
186 }
187 }
188 if (op->getNumResults() > 0)
189 rewriter.create<scf::YieldOp>(loc, elseYieldOperands);
190 return ifOp.getOperation();
191 }
192
193 static void annotate(Operation *op,
194 mlir::scf::PipeliningOption::PipelinerPart part,
195 unsigned iteration) {
196 OpBuilder b(op);
197 switch (part) {
198 case mlir::scf::PipeliningOption::PipelinerPart::Prologue:
199 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("prologue"));
200 break;
201 case mlir::scf::PipeliningOption::PipelinerPart::Kernel:
202 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("kernel"));
203 break;
204 case mlir::scf::PipeliningOption::PipelinerPart::Epilogue:
205 op->setAttr(kTestPipeliningAnnotationPart, b.getStringAttr("epilogue"));
206 break;
207 }
208 op->setAttr(kTestPipeliningAnnotationIteration,
209 b.getI32IntegerAttr(iteration));
210 }
211
212 void getDependentDialects(DialectRegistry &registry) const override {
213 registry.insert<arith::ArithDialect, memref::MemRefDialect>();
214 }
215
216 void runOnOperation() override {
217 RewritePatternSet patterns(&getContext());
218 mlir::scf::PipeliningOption options;
219 options.getScheduleFn = getSchedule;
220 options.supportDynamicLoops = true;
221 options.predicateFn = predicateOp;
222 if (annotatePipeline)
223 options.annotateFn = annotate;
224 if (noEpiloguePeeling) {
225 options.peelEpilogue = false;
226 }
227 scf::populateSCFLoopPipeliningPatterns(patterns, options);
228 (void)applyPatternsGreedily(getOperation(), std::move(patterns));
229 getOperation().walk([](Operation *op) {
230 // Clean up the markers.
231 op->removeAttr(name: kTestPipeliningStageMarker);
232 op->removeAttr(name: kTestPipeliningOpOrderMarker);
233 });
234 }
235};
236} // namespace
237
238namespace mlir {
239namespace test {
240void registerTestSCFUtilsPass() {
241 PassRegistration<TestSCFForUtilsPass>();
242 PassRegistration<TestSCFIfUtilsPass>();
243 PassRegistration<TestSCFPipeliningPass>();
244}
245} // namespace test
246} // namespace mlir
247

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp