1 | //===- IntRangeOptimizations.cpp - Optimizations based on integer ranges --===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
12 | |
13 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
14 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
17 | |
18 | namespace mlir::arith { |
19 | #define GEN_PASS_DEF_ARITHINTRANGEOPTS |
20 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
21 | } // namespace mlir::arith |
22 | |
23 | using namespace mlir; |
24 | using namespace mlir::arith; |
25 | using namespace mlir::dataflow; |
26 | |
27 | /// Returns true if 2 integer ranges have intersection. |
28 | static bool intersects(const ConstantIntRanges &lhs, |
29 | const ConstantIntRanges &rhs) { |
30 | return !((lhs.smax().slt(RHS: rhs.smin()) || lhs.smin().sgt(RHS: rhs.smax())) && |
31 | (lhs.umax().ult(RHS: rhs.umin()) || lhs.umin().ugt(RHS: rhs.umax()))); |
32 | } |
33 | |
34 | static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
35 | if (!intersects(lhs, rhs)) |
36 | return false; |
37 | |
38 | return failure(); |
39 | } |
40 | |
41 | static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
42 | if (!intersects(lhs, rhs)) |
43 | return true; |
44 | |
45 | return failure(); |
46 | } |
47 | |
48 | static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
49 | if (lhs.smax().slt(RHS: rhs.smin())) |
50 | return true; |
51 | |
52 | if (lhs.smin().sge(RHS: rhs.smax())) |
53 | return false; |
54 | |
55 | return failure(); |
56 | } |
57 | |
58 | static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
59 | if (lhs.smax().sle(RHS: rhs.smin())) |
60 | return true; |
61 | |
62 | if (lhs.smin().sgt(RHS: rhs.smax())) |
63 | return false; |
64 | |
65 | return failure(); |
66 | } |
67 | |
68 | static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
69 | return handleSlt(lhs: std::move(rhs), rhs: std::move(lhs)); |
70 | } |
71 | |
72 | static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
73 | return handleSle(lhs: std::move(rhs), rhs: std::move(lhs)); |
74 | } |
75 | |
76 | static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
77 | if (lhs.umax().ult(RHS: rhs.umin())) |
78 | return true; |
79 | |
80 | if (lhs.umin().uge(RHS: rhs.umax())) |
81 | return false; |
82 | |
83 | return failure(); |
84 | } |
85 | |
86 | static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
87 | if (lhs.umax().ule(RHS: rhs.umin())) |
88 | return true; |
89 | |
90 | if (lhs.umin().ugt(RHS: rhs.umax())) |
91 | return false; |
92 | |
93 | return failure(); |
94 | } |
95 | |
96 | static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
97 | return handleUlt(lhs: std::move(rhs), rhs: std::move(lhs)); |
98 | } |
99 | |
100 | static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) { |
101 | return handleUle(lhs: std::move(rhs), rhs: std::move(lhs)); |
102 | } |
103 | |
104 | namespace { |
105 | struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> { |
106 | |
107 | ConvertCmpOp(MLIRContext *context, DataFlowSolver &s) |
108 | : OpRewritePattern<arith::CmpIOp>(context), solver(s) {} |
109 | |
110 | LogicalResult matchAndRewrite(arith::CmpIOp op, |
111 | PatternRewriter &rewriter) const override { |
112 | auto *lhsResult = |
113 | solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs()); |
114 | if (!lhsResult || lhsResult->getValue().isUninitialized()) |
115 | return failure(); |
116 | |
117 | auto *rhsResult = |
118 | solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs()); |
119 | if (!rhsResult || rhsResult->getValue().isUninitialized()) |
120 | return failure(); |
121 | |
122 | using HandlerFunc = |
123 | FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges); |
124 | std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1> |
125 | handlers{}; |
126 | using Pred = arith::CmpIPredicate; |
127 | handlers[static_cast<size_t>(Pred::eq)] = &handleEq; |
128 | handlers[static_cast<size_t>(Pred::ne)] = &handleNe; |
129 | handlers[static_cast<size_t>(Pred::slt)] = &handleSlt; |
130 | handlers[static_cast<size_t>(Pred::sle)] = &handleSle; |
131 | handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt; |
132 | handlers[static_cast<size_t>(Pred::sge)] = &handleSge; |
133 | handlers[static_cast<size_t>(Pred::ult)] = &handleUlt; |
134 | handlers[static_cast<size_t>(Pred::ule)] = &handleUle; |
135 | handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt; |
136 | handlers[static_cast<size_t>(Pred::uge)] = &handleUge; |
137 | |
138 | HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())]; |
139 | if (!handler) |
140 | return failure(); |
141 | |
142 | ConstantIntRanges lhsValue = lhsResult->getValue().getValue(); |
143 | ConstantIntRanges rhsValue = rhsResult->getValue().getValue(); |
144 | FailureOr<bool> result = handler(lhsValue, rhsValue); |
145 | |
146 | if (failed(result)) |
147 | return failure(); |
148 | |
149 | rewriter.replaceOpWithNewOp<arith::ConstantIntOp>( |
150 | op, static_cast<int64_t>(*result), /*width*/ 1); |
151 | return success(); |
152 | } |
153 | |
154 | private: |
155 | DataFlowSolver &solver; |
156 | }; |
157 | |
158 | struct IntRangeOptimizationsPass |
159 | : public arith::impl::ArithIntRangeOptsBase<IntRangeOptimizationsPass> { |
160 | |
161 | void runOnOperation() override { |
162 | Operation *op = getOperation(); |
163 | MLIRContext *ctx = op->getContext(); |
164 | DataFlowSolver solver; |
165 | solver.load<DeadCodeAnalysis>(); |
166 | solver.load<IntegerRangeAnalysis>(); |
167 | if (failed(result: solver.initializeAndRun(top: op))) |
168 | return signalPassFailure(); |
169 | |
170 | RewritePatternSet patterns(ctx); |
171 | populateIntRangeOptimizationsPatterns(patterns, solver); |
172 | |
173 | if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) |
174 | signalPassFailure(); |
175 | } |
176 | }; |
177 | } // namespace |
178 | |
179 | void mlir::arith::populateIntRangeOptimizationsPatterns( |
180 | RewritePatternSet &patterns, DataFlowSolver &solver) { |
181 | patterns.add<ConvertCmpOp>(arg: patterns.getContext(), args&: solver); |
182 | } |
183 | |
184 | std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() { |
185 | return std::make_unique<IntRangeOptimizationsPass>(); |
186 | } |
187 | |