1//===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "mlir/Dialect/Arith/Transforms/Passes.h"
12
13#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
18namespace mlir::arith {
19#define GEN_PASS_DEF_ARITHINTRANGEOPTS
20#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
21} // namespace mlir::arith
22
23using namespace mlir;
24using namespace mlir::arith;
25using namespace mlir::dataflow;
26
27/// Returns true if 2 integer ranges have intersection.
28static bool intersects(const ConstantIntRanges &lhs,
29 const ConstantIntRanges &rhs) {
30 return !((lhs.smax().slt(RHS: rhs.smin()) || lhs.smin().sgt(RHS: rhs.smax())) &&
31 (lhs.umax().ult(RHS: rhs.umin()) || lhs.umin().ugt(RHS: rhs.umax())));
32}
33
34static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
35 if (!intersects(lhs, rhs))
36 return false;
37
38 return failure();
39}
40
41static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
42 if (!intersects(lhs, rhs))
43 return true;
44
45 return failure();
46}
47
48static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
49 if (lhs.smax().slt(RHS: rhs.smin()))
50 return true;
51
52 if (lhs.smin().sge(RHS: rhs.smax()))
53 return false;
54
55 return failure();
56}
57
58static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
59 if (lhs.smax().sle(RHS: rhs.smin()))
60 return true;
61
62 if (lhs.smin().sgt(RHS: rhs.smax()))
63 return false;
64
65 return failure();
66}
67
68static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
69 return handleSlt(lhs: std::move(rhs), rhs: std::move(lhs));
70}
71
72static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
73 return handleSle(lhs: std::move(rhs), rhs: std::move(lhs));
74}
75
76static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
77 if (lhs.umax().ult(RHS: rhs.umin()))
78 return true;
79
80 if (lhs.umin().uge(RHS: rhs.umax()))
81 return false;
82
83 return failure();
84}
85
86static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
87 if (lhs.umax().ule(RHS: rhs.umin()))
88 return true;
89
90 if (lhs.umin().ugt(RHS: rhs.umax()))
91 return false;
92
93 return failure();
94}
95
96static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
97 return handleUlt(lhs: std::move(rhs), rhs: std::move(lhs));
98}
99
100static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
101 return handleUle(lhs: std::move(rhs), rhs: std::move(lhs));
102}
103
104namespace {
105struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
106
107 ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
108 : OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
109
110 LogicalResult matchAndRewrite(arith::CmpIOp op,
111 PatternRewriter &rewriter) const override {
112 auto *lhsResult =
113 solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
114 if (!lhsResult || lhsResult->getValue().isUninitialized())
115 return failure();
116
117 auto *rhsResult =
118 solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
119 if (!rhsResult || rhsResult->getValue().isUninitialized())
120 return failure();
121
122 using HandlerFunc =
123 FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
124 std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
125 handlers{};
126 using Pred = arith::CmpIPredicate;
127 handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
128 handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
129 handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
130 handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
131 handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
132 handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
133 handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
134 handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
135 handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
136 handlers[static_cast<size_t>(Pred::uge)] = &handleUge;
137
138 HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
139 if (!handler)
140 return failure();
141
142 ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
143 ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
144 FailureOr<bool> result = handler(lhsValue, rhsValue);
145
146 if (failed(result))
147 return failure();
148
149 rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
150 op, static_cast<int64_t>(*result), /*width*/ 1);
151 return success();
152 }
153
154private:
155 DataFlowSolver &solver;
156};
157
158struct IntRangeOptimizationsPass
159 : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> {
160
161 void runOnOperation() override {
162 Operation *op = getOperation();
163 MLIRContext *ctx = op->getContext();
164 DataFlowSolver solver;
165 solver.load<DeadCodeAnalysis>();
166 solver.load<IntegerRangeAnalysis>();
167 if (failed(result: solver.initializeAndRun(top: op)))
168 return signalPassFailure();
169
170 RewritePatternSet patterns(ctx);
171 populateIntRangeOptimizationsPatterns(patterns, solver);
172
173 if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
174 signalPassFailure();
175 }
176};
177} // namespace
178
179void mlir::arith::populateIntRangeOptimizationsPatterns(
180 RewritePatternSet &patterns, DataFlowSolver &solver) {
181 patterns.add<ConvertCmpOp>(arg: patterns.getContext(), args&: solver);
182}
183
184std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
185 return std::make_unique<IntRangeOptimizationsPass>();
186}
187

source code of mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp