1 | //===- TestIntRangeInference.cpp - Create consts from range inference ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // TODO: This pass is needed to test integer range inference until that |
9 | // functionality has been integrated into SCCP. |
10 | //===----------------------------------------------------------------------===// |
11 | |
12 | #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" |
13 | #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" |
14 | #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" |
15 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
16 | #include "mlir/Pass/Pass.h" |
17 | #include "mlir/Pass/PassRegistry.h" |
18 | #include "mlir/Support/TypeID.h" |
19 | #include "mlir/Transforms/FoldUtils.h" |
20 | #include <optional> |
21 | |
22 | using namespace mlir; |
23 | using namespace mlir::dataflow; |
24 | |
25 | /// Patterned after SCCP |
26 | static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b, |
27 | OperationFolder &folder, Value value) { |
28 | auto *maybeInferredRange = |
29 | solver.lookupState<IntegerValueRangeLattice>(point: value); |
30 | if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized()) |
31 | return failure(); |
32 | const ConstantIntRanges &inferredRange = |
33 | maybeInferredRange->getValue().getValue(); |
34 | std::optional<APInt> maybeConstValue = inferredRange.getConstantValue(); |
35 | if (!maybeConstValue.has_value()) |
36 | return failure(); |
37 | |
38 | Operation *maybeDefiningOp = value.getDefiningOp(); |
39 | Dialect *valueDialect = |
40 | maybeDefiningOp ? maybeDefiningOp->getDialect() |
41 | : value.getParentRegion()->getParentOp()->getDialect(); |
42 | Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue); |
43 | Value constant = folder.getOrCreateConstant( |
44 | block: b.getInsertionBlock(), dialect: valueDialect, value: constAttr, type: value.getType()); |
45 | if (!constant) |
46 | return failure(); |
47 | |
48 | value.replaceAllUsesWith(newValue: constant); |
49 | return success(); |
50 | } |
51 | |
52 | static void rewrite(DataFlowSolver &solver, MLIRContext *context, |
53 | MutableArrayRef<Region> initialRegions) { |
54 | SmallVector<Block *> worklist; |
55 | auto addToWorklist = [&](MutableArrayRef<Region> regions) { |
56 | for (Region ®ion : regions) |
57 | for (Block &block : llvm::reverse(C&: region)) |
58 | worklist.push_back(Elt: &block); |
59 | }; |
60 | |
61 | OpBuilder builder(context); |
62 | OperationFolder folder(context); |
63 | |
64 | addToWorklist(initialRegions); |
65 | while (!worklist.empty()) { |
66 | Block *block = worklist.pop_back_val(); |
67 | |
68 | for (Operation &op : llvm::make_early_inc_range(Range&: *block)) { |
69 | builder.setInsertionPoint(&op); |
70 | |
71 | // Replace any result with constants. |
72 | bool replacedAll = op.getNumResults() != 0; |
73 | for (Value res : op.getResults()) |
74 | replacedAll &= |
75 | succeeded(result: replaceWithConstant(solver, b&: builder, folder, value: res)); |
76 | |
77 | // If all of the results of the operation were replaced, try to erase |
78 | // the operation completely. |
79 | if (replacedAll && wouldOpBeTriviallyDead(op: &op)) { |
80 | assert(op.use_empty() && "expected all uses to be replaced" ); |
81 | op.erase(); |
82 | continue; |
83 | } |
84 | |
85 | // Add any the regions of this operation to the worklist. |
86 | addToWorklist(op.getRegions()); |
87 | } |
88 | |
89 | // Replace any block arguments with constants. |
90 | builder.setInsertionPointToStart(block); |
91 | for (BlockArgument arg : block->getArguments()) |
92 | (void)replaceWithConstant(solver, b&: builder, folder, value: arg); |
93 | } |
94 | } |
95 | |
96 | namespace { |
97 | struct TestIntRangeInference |
98 | : PassWrapper<TestIntRangeInference, OperationPass<>> { |
99 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference) |
100 | |
101 | StringRef getArgument() const final { return "test-int-range-inference" ; } |
102 | StringRef getDescription() const final { |
103 | return "Test integer range inference analysis" ; |
104 | } |
105 | |
106 | void runOnOperation() override { |
107 | Operation *op = getOperation(); |
108 | DataFlowSolver solver; |
109 | solver.load<DeadCodeAnalysis>(); |
110 | solver.load<SparseConstantPropagation>(); |
111 | solver.load<IntegerRangeAnalysis>(); |
112 | if (failed(result: solver.initializeAndRun(top: op))) |
113 | return signalPassFailure(); |
114 | rewrite(solver, context: op->getContext(), initialRegions: op->getRegions()); |
115 | } |
116 | }; |
117 | } // end anonymous namespace |
118 | |
119 | namespace mlir { |
120 | namespace test { |
121 | void registerTestIntRangeInference() { |
122 | PassRegistration<TestIntRangeInference>(); |
123 | } |
124 | } // end namespace test |
125 | } // end namespace mlir |
126 | |