1//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This transformation pass performs a sparse conditional constant propagation
10// in MLIR. It identifies values known to be constant, propagates that
11// information throughout the IR, and replaces them. This is done with an
12// optimistic dataflow analysis that assumes that all values are constant until
13// proven otherwise.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/Transforms/Passes.h"
18
19#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
20#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
21#include "mlir/IR/Builders.h"
22#include "mlir/IR/Dialect.h"
23#include "mlir/Interfaces/SideEffectInterfaces.h"
24#include "mlir/Pass/Pass.h"
25#include "mlir/Transforms/FoldUtils.h"
26
27namespace mlir {
28#define GEN_PASS_DEF_SCCP
29#include "mlir/Transforms/Passes.h.inc"
30} // namespace mlir
31
32using namespace mlir;
33using namespace mlir::dataflow;
34
35//===----------------------------------------------------------------------===//
36// SCCP Rewrites
37//===----------------------------------------------------------------------===//
38
39/// Replace the given value with a constant if the corresponding lattice
40/// represents a constant. Returns success if the value was replaced, failure
41/// otherwise.
42static LogicalResult replaceWithConstant(DataFlowSolver &solver,
43 OpBuilder &builder,
44 OperationFolder &folder, Value value) {
45 auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
46 if (!lattice || lattice->getValue().isUninitialized())
47 return failure();
48 const ConstantValue &latticeValue = lattice->getValue();
49 if (!latticeValue.getConstantValue())
50 return failure();
51
52 // Attempt to materialize a constant for the given value.
53 Dialect *dialect = latticeValue.getConstantDialect();
54 Value constant = folder.getOrCreateConstant(
55 block: builder.getInsertionBlock(), dialect, value: latticeValue.getConstantValue(),
56 type: value.getType());
57 if (!constant)
58 return failure();
59
60 value.replaceAllUsesWith(newValue: constant);
61 return success();
62}
63
64/// Rewrite the given regions using the computing analysis. This replaces the
65/// uses of all values that have been computed to be constant, and erases as
66/// many newly dead operations.
67static void rewrite(DataFlowSolver &solver, MLIRContext *context,
68 MutableArrayRef<Region> initialRegions) {
69 SmallVector<Block *> worklist;
70 auto addToWorklist = [&](MutableArrayRef<Region> regions) {
71 for (Region &region : regions)
72 for (Block &block : llvm::reverse(C&: region))
73 worklist.push_back(Elt: &block);
74 };
75
76 // An operation folder used to create and unique constants.
77 OperationFolder folder(context);
78 OpBuilder builder(context);
79
80 addToWorklist(initialRegions);
81 while (!worklist.empty()) {
82 Block *block = worklist.pop_back_val();
83
84 for (Operation &op : llvm::make_early_inc_range(Range&: *block)) {
85 builder.setInsertionPoint(&op);
86
87 // Replace any result with constants.
88 bool replacedAll = op.getNumResults() != 0;
89 for (Value res : op.getResults())
90 replacedAll &=
91 succeeded(result: replaceWithConstant(solver, builder, folder, value: res));
92
93 // If all of the results of the operation were replaced, try to erase
94 // the operation completely.
95 if (replacedAll && wouldOpBeTriviallyDead(op: &op)) {
96 assert(op.use_empty() && "expected all uses to be replaced");
97 op.erase();
98 continue;
99 }
100
101 // Add any the regions of this operation to the worklist.
102 addToWorklist(op.getRegions());
103 }
104
105 // Replace any block arguments with constants.
106 builder.setInsertionPointToStart(block);
107 for (BlockArgument arg : block->getArguments())
108 (void)replaceWithConstant(solver, builder, folder, value: arg);
109 }
110}
111
112//===----------------------------------------------------------------------===//
113// SCCP Pass
114//===----------------------------------------------------------------------===//
115
116namespace {
117struct SCCP : public impl::SCCPBase<SCCP> {
118 void runOnOperation() override;
119};
120} // namespace
121
122void SCCP::runOnOperation() {
123 Operation *op = getOperation();
124
125 DataFlowSolver solver;
126 solver.load<DeadCodeAnalysis>();
127 solver.load<SparseConstantPropagation>();
128 if (failed(result: solver.initializeAndRun(top: op)))
129 return signalPassFailure();
130 rewrite(solver, context: op->getContext(), initialRegions: op->getRegions());
131}
132
133std::unique_ptr<Pass> mlir::createSCCPPass() {
134 return std::make_unique<SCCP>();
135}
136

source code of mlir/lib/Transforms/SCCP.cpp