//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
// unsigned ones when all their arguments and results are statically
// non-negative ----------------------------------------------------------===//
4 | // |
5 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
6 | // See https://llvm.org/LICENSE.txt for license information. |
7 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
8 | // |
9 | //===----------------------------------------------------------------------===// |
10 | |
11 | #include "mlir/Dialect/Arith/Transforms/Passes.h" |
12 | |
13 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
14 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
15 | #include "mlir/Dialect/Arith/IR/Arith.h" |
16 | #include "mlir/Transforms/DialectConversion.h" |
17 | |
18 | namespace mlir { |
19 | namespace arith { |
20 | #define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT |
21 | #include "mlir/Dialect/Arith/Transforms/Passes.h.inc" |
22 | } // namespace arith |
23 | } // namespace mlir |
24 | |
25 | using namespace mlir; |
26 | using namespace mlir::arith; |
27 | using namespace mlir::dataflow; |
28 | |
29 | /// Succeeds when a value is statically non-negative in that it has a lower |
30 | /// bound on its value (if it is treated as signed) and that bound is |
31 | /// non-negative. |
32 | static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) { |
33 | auto *result = solver.lookupState<IntegerValueRangeLattice>(v); |
34 | if (!result || result->getValue().isUninitialized()) |
35 | return failure(); |
36 | const ConstantIntRanges &range = result->getValue().getValue(); |
37 | return success(isSuccess: range.smin().isNonNegative()); |
38 | } |
39 | |
40 | /// Succeeds if an op can be converted to its unsigned equivalent without |
41 | /// changing its semantics. This is the case when none of its openands or |
42 | /// results can be below 0 when analyzed from a signed perspective. |
43 | static LogicalResult staticallyNonNegative(DataFlowSolver &solver, |
44 | Operation *op) { |
45 | auto nonNegativePred = [&solver](Value v) -> bool { |
46 | return succeeded(result: staticallyNonNegative(solver, v)); |
47 | }; |
48 | return success(isSuccess: llvm::all_of(Range: op->getOperands(), P: nonNegativePred) && |
49 | llvm::all_of(Range: op->getResults(), P: nonNegativePred)); |
50 | } |
51 | |
52 | /// Succeeds when the comparison predicate is a signed operation and all the |
53 | /// operands are non-negative, indicating that the cmpi operation `op` can have |
54 | /// its predicate changed to an unsigned equivalent. |
55 | static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) { |
56 | CmpIPredicate pred = op.getPredicate(); |
57 | switch (pred) { |
58 | case CmpIPredicate::sle: |
59 | case CmpIPredicate::slt: |
60 | case CmpIPredicate::sge: |
61 | case CmpIPredicate::sgt: |
62 | return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool { |
63 | return succeeded(result: staticallyNonNegative(solver, v)); |
64 | })); |
65 | default: |
66 | return failure(); |
67 | } |
68 | } |
69 | |
70 | /// Return the unsigned equivalent of a signed comparison predicate, |
71 | /// or the predicate itself if there is none. |
72 | static CmpIPredicate toUnsignedPred(CmpIPredicate pred) { |
73 | switch (pred) { |
74 | case CmpIPredicate::sle: |
75 | return CmpIPredicate::ule; |
76 | case CmpIPredicate::slt: |
77 | return CmpIPredicate::ult; |
78 | case CmpIPredicate::sge: |
79 | return CmpIPredicate::uge; |
80 | case CmpIPredicate::sgt: |
81 | return CmpIPredicate::ugt; |
82 | default: |
83 | return pred; |
84 | } |
85 | } |
86 | |
87 | namespace { |
88 | template <typename Signed, typename Unsigned> |
89 | struct ConvertOpToUnsigned : OpConversionPattern<Signed> { |
90 | using OpConversionPattern<Signed>::OpConversionPattern; |
91 | |
92 | LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor, |
93 | ConversionPatternRewriter &rw) const override { |
94 | rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), |
95 | adaptor.getOperands(), op->getAttrs()); |
96 | return success(); |
97 | } |
98 | }; |
99 | |
100 | struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> { |
101 | using OpConversionPattern<CmpIOp>::OpConversionPattern; |
102 | |
103 | LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor, |
104 | ConversionPatternRewriter &rw) const override { |
105 | rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()), |
106 | op.getLhs(), op.getRhs()); |
107 | return success(); |
108 | } |
109 | }; |
110 | |
111 | struct ArithUnsignedWhenEquivalentPass |
112 | : public arith::impl::ArithUnsignedWhenEquivalentBase< |
113 | ArithUnsignedWhenEquivalentPass> { |
114 | /// Implementation structure: first find all equivalent ops and collect them, |
115 | /// then perform all the rewrites in a second pass over the target op. This |
116 | /// ensures that analysis results are not invalidated during rewriting. |
117 | void runOnOperation() override { |
118 | Operation *op = getOperation(); |
119 | MLIRContext *ctx = op->getContext(); |
120 | DataFlowSolver solver; |
121 | solver.load<DeadCodeAnalysis>(); |
122 | solver.load<IntegerRangeAnalysis>(); |
123 | if (failed(result: solver.initializeAndRun(top: op))) |
124 | return signalPassFailure(); |
125 | |
126 | ConversionTarget target(*ctx); |
127 | target.addLegalDialect<ArithDialect>(); |
128 | target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp, |
129 | MinSIOp, MaxSIOp, ExtSIOp>( |
130 | [&solver](Operation *op) -> std::optional<bool> { |
131 | return failed(staticallyNonNegative(solver, op)); |
132 | }); |
133 | target.addDynamicallyLegalOp<CmpIOp>( |
134 | [&solver](CmpIOp op) -> std::optional<bool> { |
135 | return failed(isCmpIConvertable(solver, op)); |
136 | }); |
137 | |
138 | RewritePatternSet patterns(ctx); |
139 | patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>, |
140 | ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>, |
141 | ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>, |
142 | ConvertOpToUnsigned<RemSIOp, RemUIOp>, |
143 | ConvertOpToUnsigned<MinSIOp, MinUIOp>, |
144 | ConvertOpToUnsigned<MaxSIOp, MaxUIOp>, |
145 | ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>( |
146 | ctx); |
147 | |
148 | if (failed(applyPartialConversion(op, target, std::move(patterns)))) { |
149 | signalPassFailure(); |
150 | } |
151 | } |
152 | }; |
153 | } // end anonymous namespace |
154 | |
155 | std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() { |
156 | return std::make_unique<ArithUnsignedWhenEquivalentPass>(); |
157 | } |
158 | |