1//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Reduction Tree Pass class. It provides a framework for
10// the implementation of different reduction passes in the MLIR Reduce tool. It
11// allows for custom specification of the variant generation behavior. It
12// implements methods that define the different possible traversals of the
13// reduction tree.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/Reducer/Passes.h"
20#include "mlir/Reducer/ReductionNode.h"
21#include "mlir/Reducer/ReductionPatternInterface.h"
22#include "mlir/Reducer/Tester.h"
23#include "mlir/Rewrite/FrozenRewritePatternSet.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/ManagedStatic.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_REDUCTIONTREE
33#include "mlir/Reducer/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38/// We implicitly number each operation in the region and if an operation's
39/// number falls into rangeToKeep, we need to keep it and apply the given
40/// rewrite patterns on it.
41static void applyPatterns(Region &region,
42 const FrozenRewritePatternSet &patterns,
43 ArrayRef<ReductionNode::Range> rangeToKeep,
44 bool eraseOpNotInRange) {
45 std::vector<Operation *> opsNotInRange;
46 std::vector<Operation *> opsInRange;
47 size_t keepIndex = 0;
48 for (const auto &op : enumerate(region.getOps())) {
49 int index = op.index();
50 if (keepIndex < rangeToKeep.size() &&
51 index == rangeToKeep[keepIndex].second)
52 ++keepIndex;
53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54 opsNotInRange.push_back(&op.value());
55 else
56 opsInRange.push_back(&op.value());
57 }
58
59 // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
60 // matching in above iteration. Besides, erase op not-in-range may end up in
61 // invalid module, so `applyOpPatternsAndFold` should come before that
62 // transform.
63 for (Operation *op : opsInRange) {
64 // `applyOpPatternsAndFold` returns whether the op is convered. Omit it
65 // because we don't have expectation this reduction will be success or not.
66 GreedyRewriteConfig config;
67 config.strictMode = GreedyRewriteStrictness::ExistingOps;
68 (void)applyOpPatternsAndFold(op, patterns, config);
69 }
70
71 if (eraseOpNotInRange)
72 for (Operation *op : opsNotInRange) {
73 op->dropAllUses();
74 op->erase();
75 }
76}
77
78/// We will apply the reducer patterns to the operations in the ranges specified
79/// by ReductionNode. Note that we are not able to remove an operation without
80/// replacing it with another valid operation. However, The validity of module
81/// reduction is based on the Tester provided by the user and that means certain
82/// invalid module is still interested by the use. Thus we provide an
83/// alternative way to remove operations, which is using `eraseOpNotInRange` to
84/// erase the operations not in the range specified by ReductionNode.
85template <typename IteratorType>
86static LogicalResult findOptimal(ModuleOp module, Region &region,
87 const FrozenRewritePatternSet &patterns,
88 const Tester &test, bool eraseOpNotInRange) {
89 std::pair<Tester::Interestingness, size_t> initStatus =
90 test.isInteresting(module);
91 // While exploring the reduction tree, we always branch from an interesting
92 // node. Thus the root node must be interesting.
93 if (initStatus.first != Tester::Interestingness::True)
94 return module.emitWarning() << "uninterested module will not be reduced";
95
96 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
97
98 std::vector<ReductionNode::Range> ranges{
99 {0, std::distance(first: region.op_begin(), last: region.op_end())}};
100
101 ReductionNode *root = allocator.Allocate();
102 new (root) ReductionNode(nullptr, ranges, allocator);
103 // Duplicate the module for root node and locate the region in the copy.
104 if (failed(root->initialize(parentModule: module, parentRegion&: region)))
105 llvm_unreachable("unexpected initialization failure");
106 root->update(result: initStatus);
107
108 ReductionNode *smallestNode = root;
109 IteratorType iter(root);
110
111 while (iter != IteratorType::end()) {
112 ReductionNode &currentNode = *iter;
113 Region &curRegion = currentNode.getRegion();
114
115 applyPatterns(region&: curRegion, patterns, rangeToKeep: currentNode.getRanges(),
116 eraseOpNotInRange);
117 currentNode.update(result: test.isInteresting(currentNode.getModule()));
118
119 if (currentNode.isInteresting() == Tester::Interestingness::True &&
120 currentNode.getSize() < smallestNode->getSize())
121 smallestNode = &currentNode;
122
123 ++iter;
124 }
125
126 // At here, we have found an optimal path to reduce the given region. Retrieve
127 // the path and apply the reducer to it.
128 SmallVector<ReductionNode *> trace;
129 ReductionNode *curNode = smallestNode;
130 trace.push_back(Elt: curNode);
131 while (curNode != root) {
132 curNode = curNode->getParent();
133 trace.push_back(Elt: curNode);
134 }
135
136 // Reduce the region through the optimal path.
137 while (!trace.empty()) {
138 ReductionNode *top = trace.pop_back_val();
139 applyPatterns(region, patterns, rangeToKeep: top->getStartRanges(), eraseOpNotInRange);
140 }
141
142 if (test.isInteresting(module).first != Tester::Interestingness::True)
143 llvm::report_fatal_error(reason: "Reduced module is not interesting");
144 if (test.isInteresting(module).second != smallestNode->getSize())
145 llvm::report_fatal_error(
146 reason: "Reduced module doesn't have consistent size with smallestNode");
147 return success();
148}
149
150template <typename IteratorType>
151static LogicalResult findOptimal(ModuleOp module, Region &region,
152 const FrozenRewritePatternSet &patterns,
153 const Tester &test) {
154 // We separate the reduction process into 2 steps, the first one is to erase
155 // redundant operations and the second one is to apply the reducer patterns.
156
157 // In the first phase, we don't apply any patterns so that we only select the
158 // range of operations to keep to the module stay interesting.
159 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
160 /*eraseOpNotInRange=*/true)))
161 return failure();
162 // In the second phase, we suppose that no operation is redundant, so we try
163 // to rewrite the operation into simpler form.
164 return findOptimal<IteratorType>(module, region, patterns, test,
165 /*eraseOpNotInRange=*/false);
166}
167
168namespace {
169
170//===----------------------------------------------------------------------===//
171// Reduction Pattern Interface Collection
172//===----------------------------------------------------------------------===//
173
174class ReductionPatternInterfaceCollection
175 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
176public:
177 using Base::Base;
178
179 // Collect the reduce patterns defined by each dialect.
180 void populateReductionPatterns(RewritePatternSet &pattern) const {
181 for (const DialectReductionPatternInterface &interface : *this)
182 interface.populateReductionPatterns(patterns&: pattern);
183 }
184};
185
186//===----------------------------------------------------------------------===//
187// ReductionTreePass
188//===----------------------------------------------------------------------===//
189
190/// This class defines the Reduction Tree Pass. It provides a framework to
191/// to implement a reduction pass using a tree structure to keep track of the
192/// generated reduced variants.
193class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> {
194public:
195 ReductionTreePass() = default;
196 ReductionTreePass(const ReductionTreePass &pass) = default;
197
198 LogicalResult initialize(MLIRContext *context) override;
199
200 /// Runs the pass instance in the pass pipeline.
201 void runOnOperation() override;
202
203private:
204 LogicalResult reduceOp(ModuleOp module, Region &region);
205
206 FrozenRewritePatternSet reducerPatterns;
207};
208
209} // namespace
210
211LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
212 RewritePatternSet patterns(context);
213 ReductionPatternInterfaceCollection reducePatternCollection(context);
214 reducePatternCollection.populateReductionPatterns(pattern&: patterns);
215 reducerPatterns = std::move(patterns);
216 return success();
217}
218
219void ReductionTreePass::runOnOperation() {
220 Operation *topOperation = getOperation();
221 while (topOperation->getParentOp() != nullptr)
222 topOperation = topOperation->getParentOp();
223 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
224 if (!module) {
225 emitError(getOperation()->getLoc())
226 << "top-level op must be 'builtin.module'";
227 return signalPassFailure();
228 }
229
230 SmallVector<Operation *, 8> workList;
231 workList.push_back(Elt: getOperation());
232
233 do {
234 Operation *op = workList.pop_back_val();
235
236 for (Region &region : op->getRegions())
237 if (!region.empty())
238 if (failed(reduceOp(module: module, region)))
239 return signalPassFailure();
240
241 for (Region &region : op->getRegions())
242 for (Operation &op : region.getOps())
243 if (op.getNumRegions() != 0)
244 workList.push_back(Elt: &op);
245 } while (!workList.empty());
246}
247
248LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
249 Tester test(testerName, testerArgs);
250 switch (traversalModeId) {
251 case TraversalMode::SinglePath:
252 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
253 module, region, reducerPatterns, test);
254 default:
255 return module.emitError() << "unsupported traversal mode detected";
256 }
257}
258
259std::unique_ptr<Pass> mlir::createReductionTreePass() {
260 return std::make_unique<ReductionTreePass>();
261}
262

// Source: mlir/lib/Reducer/ReductionTreePass.cpp