//===- TestIntRangeInference.cpp - Create consts from range inference ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// TODO: This pass is needed to test integer range inference until that
// functionality has been integrated into SCCP.
//===----------------------------------------------------------------------===//
11
12#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
13#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
14#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
15#include "mlir/Interfaces/SideEffectInterfaces.h"
16#include "mlir/Pass/Pass.h"
17#include "mlir/Pass/PassRegistry.h"
18#include "mlir/Support/TypeID.h"
19#include "mlir/Transforms/FoldUtils.h"
20#include <optional>
21
22using namespace mlir;
23using namespace mlir::dataflow;
24
25/// Patterned after SCCP
26static LogicalResult replaceWithConstant(DataFlowSolver &solver, OpBuilder &b,
27 OperationFolder &folder, Value value) {
28 auto *maybeInferredRange =
29 solver.lookupState<IntegerValueRangeLattice>(point: value);
30 if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
31 return failure();
32 const ConstantIntRanges &inferredRange =
33 maybeInferredRange->getValue().getValue();
34 std::optional<APInt> maybeConstValue = inferredRange.getConstantValue();
35 if (!maybeConstValue.has_value())
36 return failure();
37
38 Operation *maybeDefiningOp = value.getDefiningOp();
39 Dialect *valueDialect =
40 maybeDefiningOp ? maybeDefiningOp->getDialect()
41 : value.getParentRegion()->getParentOp()->getDialect();
42 Attribute constAttr = b.getIntegerAttr(value.getType(), *maybeConstValue);
43 Value constant = folder.getOrCreateConstant(
44 block: b.getInsertionBlock(), dialect: valueDialect, value: constAttr, type: value.getType());
45 if (!constant)
46 return failure();
47
48 value.replaceAllUsesWith(newValue: constant);
49 return success();
50}
51
52static void rewrite(DataFlowSolver &solver, MLIRContext *context,
53 MutableArrayRef<Region> initialRegions) {
54 SmallVector<Block *> worklist;
55 auto addToWorklist = [&](MutableArrayRef<Region> regions) {
56 for (Region &region : regions)
57 for (Block &block : llvm::reverse(C&: region))
58 worklist.push_back(Elt: &block);
59 };
60
61 OpBuilder builder(context);
62 OperationFolder folder(context);
63
64 addToWorklist(initialRegions);
65 while (!worklist.empty()) {
66 Block *block = worklist.pop_back_val();
67
68 for (Operation &op : llvm::make_early_inc_range(Range&: *block)) {
69 builder.setInsertionPoint(&op);
70
71 // Replace any result with constants.
72 bool replacedAll = op.getNumResults() != 0;
73 for (Value res : op.getResults())
74 replacedAll &=
75 succeeded(result: replaceWithConstant(solver, b&: builder, folder, value: res));
76
77 // If all of the results of the operation were replaced, try to erase
78 // the operation completely.
79 if (replacedAll && wouldOpBeTriviallyDead(op: &op)) {
80 assert(op.use_empty() && "expected all uses to be replaced");
81 op.erase();
82 continue;
83 }
84
85 // Add any the regions of this operation to the worklist.
86 addToWorklist(op.getRegions());
87 }
88
89 // Replace any block arguments with constants.
90 builder.setInsertionPointToStart(block);
91 for (BlockArgument arg : block->getArguments())
92 (void)replaceWithConstant(solver, b&: builder, folder, value: arg);
93 }
94}
95
96namespace {
97struct TestIntRangeInference
98 : PassWrapper<TestIntRangeInference, OperationPass<>> {
99 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestIntRangeInference)
100
101 StringRef getArgument() const final { return "test-int-range-inference"; }
102 StringRef getDescription() const final {
103 return "Test integer range inference analysis";
104 }
105
106 void runOnOperation() override {
107 Operation *op = getOperation();
108 DataFlowSolver solver;
109 solver.load<DeadCodeAnalysis>();
110 solver.load<SparseConstantPropagation>();
111 solver.load<IntegerRangeAnalysis>();
112 if (failed(result: solver.initializeAndRun(top: op)))
113 return signalPassFailure();
114 rewrite(solver, context: op->getContext(), initialRegions: op->getRegions());
115 }
116};
117} // end anonymous namespace
118
119namespace mlir {
120namespace test {
121void registerTestIntRangeInference() {
122 PassRegistration<TestIntRangeInference>();
123}
124} // end namespace test
125} // end namespace mlir
126

// Source: mlir/test/lib/Transforms/TestIntRangeInference.cpp