1//===- TestSCFWrapInZeroTripCheck.cpp -- Pass to test SCF zero-trip-check -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the passes to test wrap-in-zero-trip-check transforms on
10// SCF loop ops.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Dialect/SCF/IR/SCF.h"
16#include "mlir/Dialect/SCF/Transforms/Patterns.h"
17#include "mlir/Dialect/SCF/Transforms/Transforms.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21
22using namespace mlir;
23
24namespace {
25
26struct TestWrapWhileLoopInZeroTripCheckPass
27 : public PassWrapper<TestWrapWhileLoopInZeroTripCheckPass,
28 OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
30 TestWrapWhileLoopInZeroTripCheckPass)
31
32 StringRef getArgument() const final {
33 return "test-wrap-scf-while-loop-in-zero-trip-check";
34 }
35
36 StringRef getDescription() const final {
37 return "test scf::wrapWhileLoopInZeroTripCheck";
38 }
39
40 TestWrapWhileLoopInZeroTripCheckPass() = default;
41 TestWrapWhileLoopInZeroTripCheckPass(
42 const TestWrapWhileLoopInZeroTripCheckPass &) {}
43 explicit TestWrapWhileLoopInZeroTripCheckPass(bool forceCreateCheckParam) {
44 forceCreateCheck = forceCreateCheckParam;
45 }
46
47 void runOnOperation() override {
48 func::FuncOp func = getOperation();
49 MLIRContext *context = &getContext();
50 IRRewriter rewriter(context);
51 if (forceCreateCheck) {
52 func.walk([&](scf::WhileOp op) {
53 FailureOr<scf::WhileOp> result =
54 scf::wrapWhileLoopInZeroTripCheck(op, rewriter, forceCreateCheck);
55 // Ignore not implemented failure in tests. The expected output should
56 // catch problems (e.g. transformation doesn't happen).
57 (void)result;
58 });
59 } else {
60 RewritePatternSet patterns(context);
61 scf::populateSCFRotateWhileLoopPatterns(patterns);
62 (void)applyPatternsGreedily(func, std::move(patterns));
63 }
64 }
65
66 Option<bool> forceCreateCheck{
67 *this, "force-create-check",
68 llvm::cl::desc("Force to create zero-trip-check."),
69 llvm::cl::init(false)};
70};
71
72} // namespace
73
74namespace mlir {
75namespace test {
76void registerTestSCFWrapInZeroTripCheckPasses() {
77 PassRegistration<TestWrapWhileLoopInZeroTripCheckPass>();
78}
79} // namespace test
80} // namespace mlir
81

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/test/lib/Dialect/SCF/TestSCFWrapInZeroTripCheck.cpp