1//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2// unsigned
3// ones when all their arguments and results are statically non-negative --===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir/Dialect/Arith/Transforms/Passes.h"
12
13#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace mlir {
19namespace arith {
20#define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENT
21#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
22} // namespace arith
23} // namespace mlir
24
25using namespace mlir;
26using namespace mlir::arith;
27using namespace mlir::dataflow;
28
29/// Succeeds when a value is statically non-negative in that it has a lower
30/// bound on its value (if it is treated as signed) and that bound is
31/// non-negative.
32static LogicalResult staticallyNonNegative(DataFlowSolver &solver, Value v) {
33 auto *result = solver.lookupState<IntegerValueRangeLattice>(v);
34 if (!result || result->getValue().isUninitialized())
35 return failure();
36 const ConstantIntRanges &range = result->getValue().getValue();
37 return success(isSuccess: range.smin().isNonNegative());
38}
39
40/// Succeeds if an op can be converted to its unsigned equivalent without
41/// changing its semantics. This is the case when none of its openands or
42/// results can be below 0 when analyzed from a signed perspective.
43static LogicalResult staticallyNonNegative(DataFlowSolver &solver,
44 Operation *op) {
45 auto nonNegativePred = [&solver](Value v) -> bool {
46 return succeeded(result: staticallyNonNegative(solver, v));
47 };
48 return success(isSuccess: llvm::all_of(Range: op->getOperands(), P: nonNegativePred) &&
49 llvm::all_of(Range: op->getResults(), P: nonNegativePred));
50}
51
52/// Succeeds when the comparison predicate is a signed operation and all the
53/// operands are non-negative, indicating that the cmpi operation `op` can have
54/// its predicate changed to an unsigned equivalent.
55static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
56 CmpIPredicate pred = op.getPredicate();
57 switch (pred) {
58 case CmpIPredicate::sle:
59 case CmpIPredicate::slt:
60 case CmpIPredicate::sge:
61 case CmpIPredicate::sgt:
62 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
63 return succeeded(result: staticallyNonNegative(solver, v));
64 }));
65 default:
66 return failure();
67 }
68}
69
70/// Return the unsigned equivalent of a signed comparison predicate,
71/// or the predicate itself if there is none.
72static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
73 switch (pred) {
74 case CmpIPredicate::sle:
75 return CmpIPredicate::ule;
76 case CmpIPredicate::slt:
77 return CmpIPredicate::ult;
78 case CmpIPredicate::sge:
79 return CmpIPredicate::uge;
80 case CmpIPredicate::sgt:
81 return CmpIPredicate::ugt;
82 default:
83 return pred;
84 }
85}
86
87namespace {
88template <typename Signed, typename Unsigned>
89struct ConvertOpToUnsigned : OpConversionPattern<Signed> {
90 using OpConversionPattern<Signed>::OpConversionPattern;
91
92 LogicalResult matchAndRewrite(Signed op, typename Signed::Adaptor adaptor,
93 ConversionPatternRewriter &rw) const override {
94 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(),
95 adaptor.getOperands(), op->getAttrs());
96 return success();
97 }
98};
99
100struct ConvertCmpIToUnsigned : OpConversionPattern<CmpIOp> {
101 using OpConversionPattern<CmpIOp>::OpConversionPattern;
102
103 LogicalResult matchAndRewrite(CmpIOp op, CmpIOpAdaptor adaptor,
104 ConversionPatternRewriter &rw) const override {
105 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
106 op.getLhs(), op.getRhs());
107 return success();
108 }
109};
110
111struct ArithUnsignedWhenEquivalentPass
112 : public arith::impl::ArithUnsignedWhenEquivalentBase<
113 ArithUnsignedWhenEquivalentPass> {
114 /// Implementation structure: first find all equivalent ops and collect them,
115 /// then perform all the rewrites in a second pass over the target op. This
116 /// ensures that analysis results are not invalidated during rewriting.
117 void runOnOperation() override {
118 Operation *op = getOperation();
119 MLIRContext *ctx = op->getContext();
120 DataFlowSolver solver;
121 solver.load<DeadCodeAnalysis>();
122 solver.load<IntegerRangeAnalysis>();
123 if (failed(result: solver.initializeAndRun(top: op)))
124 return signalPassFailure();
125
126 ConversionTarget target(*ctx);
127 target.addLegalDialect<ArithDialect>();
128 target.addDynamicallyLegalOp<DivSIOp, CeilDivSIOp, FloorDivSIOp, RemSIOp,
129 MinSIOp, MaxSIOp, ExtSIOp>(
130 [&solver](Operation *op) -> std::optional<bool> {
131 return failed(staticallyNonNegative(solver, op));
132 });
133 target.addDynamicallyLegalOp<CmpIOp>(
134 [&solver](CmpIOp op) -> std::optional<bool> {
135 return failed(isCmpIConvertable(solver, op));
136 });
137
138 RewritePatternSet patterns(ctx);
139 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
140 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
141 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
142 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
143 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
144 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
145 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
146 ctx);
147
148 if (failed(applyPartialConversion(op, target, std::move(patterns)))) {
149 signalPassFailure();
150 }
151 }
152};
153} // end anonymous namespace
154
155std::unique_ptr<Pass> mlir::arith::createArithUnsignedWhenEquivalentPass() {
156 return std::make_unique<ArithUnsignedWhenEquivalentPass>();
157}
158

// Source: mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp