1//===- RotateWhileLoop.cpp - scf.while loop rotation ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Rotates `scf.while` loops.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SCF/Transforms/Patterns.h"
14
15#include "mlir/Dialect/SCF/IR/SCF.h"
16
17using namespace mlir;
18
19namespace {
20struct RotateWhileLoopPattern : OpRewritePattern<scf::WhileOp> {
21 using OpRewritePattern<scf::WhileOp>::OpRewritePattern;
22
23 LogicalResult matchAndRewrite(scf::WhileOp whileOp,
24 PatternRewriter &rewriter) const final {
25 // Setting this option would lead to infinite recursion on a greedy driver
26 // as 'do-while' loops wouldn't be skipped.
27 constexpr bool forceCreateCheck = false;
28 FailureOr<scf::WhileOp> result =
29 scf::wrapWhileLoopInZeroTripCheck(whileOp, rewriter, forceCreateCheck);
30 // scf::wrapWhileLoopInZeroTripCheck hasn't yet implemented a failure
31 // mechanism. 'do-while' loops are simply returned unmodified. In order to
32 // stop recursion, we check input and output operations differ.
33 return success(succeeded(result) && *result != whileOp);
34 }
35};
36} // namespace
37
38namespace mlir {
39namespace scf {
40void populateSCFRotateWhileLoopPatterns(RewritePatternSet &patterns) {
41 patterns.add<RotateWhileLoopPattern>(arg: patterns.getContext());
42}
43} // namespace scf
44} // namespace mlir
45

source code of mlir/lib/Dialect/SCF/Transforms/RotateWhileLoop.cpp