1//===- UnsignedWhenEquivalent.cpp - Pass to replace signed operations with
2// unsigned
3// ones when all their arguments and results are statically non-negative --===//
4//
5// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6// See https://llvm.org/LICENSE.txt for license information.
7// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8//
9//===----------------------------------------------------------------------===//
10
11#include "mlir/Dialect/Arith/Transforms/Passes.h"
12
13#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15#include "mlir/Dialect/Arith/IR/Arith.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Transforms/WalkPatternRewriteDriver.h"
18
19namespace mlir {
20namespace arith {
21#define GEN_PASS_DEF_ARITHUNSIGNEDWHENEQUIVALENTPASS
22#include "mlir/Dialect/Arith/Transforms/Passes.h.inc"
23} // namespace arith
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::arith;
28using namespace mlir::dataflow;
29
30/// Succeeds when the comparison predicate is a signed operation and all the
31/// operands are non-negative, indicating that the cmpi operation `op` can have
32/// its predicate changed to an unsigned equivalent.
33static LogicalResult isCmpIConvertable(DataFlowSolver &solver, CmpIOp op) {
34 CmpIPredicate pred = op.getPredicate();
35 switch (pred) {
36 case CmpIPredicate::sle:
37 case CmpIPredicate::slt:
38 case CmpIPredicate::sge:
39 case CmpIPredicate::sgt:
40 return success(llvm::all_of(op.getOperands(), [&solver](Value v) -> bool {
41 return succeeded(Result: staticallyNonNegative(solver, v));
42 }));
43 default:
44 return failure();
45 }
46}
47
48/// Return the unsigned equivalent of a signed comparison predicate,
49/// or the predicate itself if there is none.
50static CmpIPredicate toUnsignedPred(CmpIPredicate pred) {
51 switch (pred) {
52 case CmpIPredicate::sle:
53 return CmpIPredicate::ule;
54 case CmpIPredicate::slt:
55 return CmpIPredicate::ult;
56 case CmpIPredicate::sge:
57 return CmpIPredicate::uge;
58 case CmpIPredicate::sgt:
59 return CmpIPredicate::ugt;
60 default:
61 return pred;
62 }
63}
64
65namespace {
66class DataFlowListener : public RewriterBase::Listener {
67public:
68 DataFlowListener(DataFlowSolver &s) : s(s) {}
69
70protected:
71 void notifyOperationErased(Operation *op) override {
72 s.eraseState(anchor: s.getProgramPointAfter(op));
73 for (Value res : op->getResults())
74 s.eraseState(anchor: res);
75 }
76
77 DataFlowSolver &s;
78};
79
80// TODO: IntegerRangeAnalysis internally assumes index is 64bit and this pattern
81// (via staticallyNonNegative) relies on this. These transformations may not be
82// valid for 32bit index, need more investigation.
83
84template <typename Signed, typename Unsigned>
85struct ConvertOpToUnsigned final : OpRewritePattern<Signed> {
86 ConvertOpToUnsigned(MLIRContext *context, DataFlowSolver &s)
87 : OpRewritePattern<Signed>(context), solver(s) {}
88
89 LogicalResult matchAndRewrite(Signed op, PatternRewriter &rw) const override {
90 if (failed(
91 staticallyNonNegative(this->solver, static_cast<Operation *>(op))))
92 return failure();
93
94 rw.replaceOpWithNewOp<Unsigned>(op, op->getResultTypes(), op->getOperands(),
95 op->getAttrs());
96 return success();
97 }
98
99private:
100 DataFlowSolver &solver;
101};
102
103struct ConvertCmpIToUnsigned final : OpRewritePattern<CmpIOp> {
104 ConvertCmpIToUnsigned(MLIRContext *context, DataFlowSolver &s)
105 : OpRewritePattern<CmpIOp>(context), solver(s) {}
106
107 LogicalResult matchAndRewrite(CmpIOp op, PatternRewriter &rw) const override {
108 if (failed(isCmpIConvertable(this->solver, op)))
109 return failure();
110
111 rw.replaceOpWithNewOp<CmpIOp>(op, toUnsignedPred(op.getPredicate()),
112 op.getLhs(), op.getRhs());
113 return success();
114 }
115
116private:
117 DataFlowSolver &solver;
118};
119
120struct ArithUnsignedWhenEquivalentPass
121 : public arith::impl::ArithUnsignedWhenEquivalentPassBase<
122 ArithUnsignedWhenEquivalentPass> {
123
124 void runOnOperation() override {
125 Operation *op = getOperation();
126 MLIRContext *ctx = op->getContext();
127 DataFlowSolver solver;
128 solver.load<DeadCodeAnalysis>();
129 solver.load<IntegerRangeAnalysis>();
130 if (failed(Result: solver.initializeAndRun(top: op)))
131 return signalPassFailure();
132
133 DataFlowListener listener(solver);
134
135 RewritePatternSet patterns(ctx);
136 populateUnsignedWhenEquivalentPatterns(patterns, solver);
137
138 walkAndApplyPatterns(op, std::move(patterns), &listener);
139 }
140};
141} // end anonymous namespace
142
143void mlir::arith::populateUnsignedWhenEquivalentPatterns(
144 RewritePatternSet &patterns, DataFlowSolver &solver) {
145 patterns.add<ConvertOpToUnsigned<DivSIOp, DivUIOp>,
146 ConvertOpToUnsigned<CeilDivSIOp, CeilDivUIOp>,
147 ConvertOpToUnsigned<FloorDivSIOp, DivUIOp>,
148 ConvertOpToUnsigned<RemSIOp, RemUIOp>,
149 ConvertOpToUnsigned<MinSIOp, MinUIOp>,
150 ConvertOpToUnsigned<MaxSIOp, MaxUIOp>,
151 ConvertOpToUnsigned<ExtSIOp, ExtUIOp>, ConvertCmpIToUnsigned>(
152 patterns.getContext(), solver);
153}
154

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Dialect/Arith/Transforms/UnsignedWhenEquivalent.cpp