1 | //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the Reduction Tree Pass class. It provides a framework for |
10 | // the implementation of different reduction passes in the MLIR Reduce tool. It |
11 | // allows for custom specification of the variant generation behavior. It |
12 | // implements methods that define the different possible traversals of the |
13 | // reduction tree. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "mlir/Reducer/Passes.h" |
20 | #include "mlir/Reducer/ReductionNode.h" |
21 | #include "mlir/Reducer/ReductionPatternInterface.h" |
22 | #include "mlir/Reducer/Tester.h" |
23 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/Support/Allocator.h" |
29 | #include "llvm/Support/ManagedStatic.h" |
30 | |
31 | namespace mlir { |
32 | #define GEN_PASS_DEF_REDUCTIONTREE |
33 | #include "mlir/Reducer/Passes.h.inc" |
34 | } // namespace mlir |
35 | |
36 | using namespace mlir; |
37 | |
38 | /// We implicitly number each operation in the region and if an operation's |
39 | /// number falls into rangeToKeep, we need to keep it and apply the given |
40 | /// rewrite patterns on it. |
41 | static void applyPatterns(Region ®ion, |
42 | const FrozenRewritePatternSet &patterns, |
43 | ArrayRef<ReductionNode::Range> rangeToKeep, |
44 | bool eraseOpNotInRange) { |
45 | std::vector<Operation *> opsNotInRange; |
46 | std::vector<Operation *> opsInRange; |
47 | size_t keepIndex = 0; |
48 | for (const auto &op : enumerate(region.getOps())) { |
49 | int index = op.index(); |
50 | if (keepIndex < rangeToKeep.size() && |
51 | index == rangeToKeep[keepIndex].second) |
52 | ++keepIndex; |
53 | if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) |
54 | opsNotInRange.push_back(&op.value()); |
55 | else |
56 | opsInRange.push_back(&op.value()); |
57 | } |
58 | |
59 | // `applyOpPatternsAndFold` may erase the ops so we can't do the pattern |
60 | // matching in above iteration. Besides, erase op not-in-range may end up in |
61 | // invalid module, so `applyOpPatternsAndFold` should come before that |
62 | // transform. |
63 | for (Operation *op : opsInRange) { |
64 | // `applyOpPatternsAndFold` returns whether the op is convered. Omit it |
65 | // because we don't have expectation this reduction will be success or not. |
66 | GreedyRewriteConfig config; |
67 | config.strictMode = GreedyRewriteStrictness::ExistingOps; |
68 | (void)applyOpPatternsAndFold(op, patterns, config); |
69 | } |
70 | |
71 | if (eraseOpNotInRange) |
72 | for (Operation *op : opsNotInRange) { |
73 | op->dropAllUses(); |
74 | op->erase(); |
75 | } |
76 | } |
77 | |
78 | /// We will apply the reducer patterns to the operations in the ranges specified |
79 | /// by ReductionNode. Note that we are not able to remove an operation without |
80 | /// replacing it with another valid operation. However, The validity of module |
81 | /// reduction is based on the Tester provided by the user and that means certain |
82 | /// invalid module is still interested by the use. Thus we provide an |
83 | /// alternative way to remove operations, which is using `eraseOpNotInRange` to |
84 | /// erase the operations not in the range specified by ReductionNode. |
85 | template <typename IteratorType> |
86 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
87 | const FrozenRewritePatternSet &patterns, |
88 | const Tester &test, bool eraseOpNotInRange) { |
89 | std::pair<Tester::Interestingness, size_t> initStatus = |
90 | test.isInteresting(module); |
91 | // While exploring the reduction tree, we always branch from an interesting |
92 | // node. Thus the root node must be interesting. |
93 | if (initStatus.first != Tester::Interestingness::True) |
94 | return module.emitWarning() << "uninterested module will not be reduced" ; |
95 | |
96 | llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; |
97 | |
98 | std::vector<ReductionNode::Range> ranges{ |
99 | {0, std::distance(first: region.op_begin(), last: region.op_end())}}; |
100 | |
101 | ReductionNode *root = allocator.Allocate(); |
102 | new (root) ReductionNode(nullptr, ranges, allocator); |
103 | // Duplicate the module for root node and locate the region in the copy. |
104 | if (failed(root->initialize(parentModule: module, parentRegion&: region))) |
105 | llvm_unreachable("unexpected initialization failure" ); |
106 | root->update(result: initStatus); |
107 | |
108 | ReductionNode *smallestNode = root; |
109 | IteratorType iter(root); |
110 | |
111 | while (iter != IteratorType::end()) { |
112 | ReductionNode ¤tNode = *iter; |
113 | Region &curRegion = currentNode.getRegion(); |
114 | |
115 | applyPatterns(region&: curRegion, patterns, rangeToKeep: currentNode.getRanges(), |
116 | eraseOpNotInRange); |
117 | currentNode.update(result: test.isInteresting(currentNode.getModule())); |
118 | |
119 | if (currentNode.isInteresting() == Tester::Interestingness::True && |
120 | currentNode.getSize() < smallestNode->getSize()) |
121 | smallestNode = ¤tNode; |
122 | |
123 | ++iter; |
124 | } |
125 | |
126 | // At here, we have found an optimal path to reduce the given region. Retrieve |
127 | // the path and apply the reducer to it. |
128 | SmallVector<ReductionNode *> trace; |
129 | ReductionNode *curNode = smallestNode; |
130 | trace.push_back(Elt: curNode); |
131 | while (curNode != root) { |
132 | curNode = curNode->getParent(); |
133 | trace.push_back(Elt: curNode); |
134 | } |
135 | |
136 | // Reduce the region through the optimal path. |
137 | while (!trace.empty()) { |
138 | ReductionNode *top = trace.pop_back_val(); |
139 | applyPatterns(region, patterns, rangeToKeep: top->getStartRanges(), eraseOpNotInRange); |
140 | } |
141 | |
142 | if (test.isInteresting(module).first != Tester::Interestingness::True) |
143 | llvm::report_fatal_error(reason: "Reduced module is not interesting" ); |
144 | if (test.isInteresting(module).second != smallestNode->getSize()) |
145 | llvm::report_fatal_error( |
146 | reason: "Reduced module doesn't have consistent size with smallestNode" ); |
147 | return success(); |
148 | } |
149 | |
150 | template <typename IteratorType> |
151 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
152 | const FrozenRewritePatternSet &patterns, |
153 | const Tester &test) { |
154 | // We separate the reduction process into 2 steps, the first one is to erase |
155 | // redundant operations and the second one is to apply the reducer patterns. |
156 | |
157 | // In the first phase, we don't apply any patterns so that we only select the |
158 | // range of operations to keep to the module stay interesting. |
159 | if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, |
160 | /*eraseOpNotInRange=*/true))) |
161 | return failure(); |
162 | // In the second phase, we suppose that no operation is redundant, so we try |
163 | // to rewrite the operation into simpler form. |
164 | return findOptimal<IteratorType>(module, region, patterns, test, |
165 | /*eraseOpNotInRange=*/false); |
166 | } |
167 | |
168 | namespace { |
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // Reduction Pattern Interface Collection |
172 | //===----------------------------------------------------------------------===// |
173 | |
174 | class ReductionPatternInterfaceCollection |
175 | : public DialectInterfaceCollection<DialectReductionPatternInterface> { |
176 | public: |
177 | using Base::Base; |
178 | |
179 | // Collect the reduce patterns defined by each dialect. |
180 | void populateReductionPatterns(RewritePatternSet &pattern) const { |
181 | for (const DialectReductionPatternInterface &interface : *this) |
182 | interface.populateReductionPatterns(patterns&: pattern); |
183 | } |
184 | }; |
185 | |
186 | //===----------------------------------------------------------------------===// |
187 | // ReductionTreePass |
188 | //===----------------------------------------------------------------------===// |
189 | |
190 | /// This class defines the Reduction Tree Pass. It provides a framework to |
191 | /// to implement a reduction pass using a tree structure to keep track of the |
192 | /// generated reduced variants. |
193 | class ReductionTreePass : public impl::ReductionTreeBase<ReductionTreePass> { |
194 | public: |
195 | ReductionTreePass() = default; |
196 | ReductionTreePass(const ReductionTreePass &pass) = default; |
197 | |
198 | LogicalResult initialize(MLIRContext *context) override; |
199 | |
200 | /// Runs the pass instance in the pass pipeline. |
201 | void runOnOperation() override; |
202 | |
203 | private: |
204 | LogicalResult reduceOp(ModuleOp module, Region ®ion); |
205 | |
206 | FrozenRewritePatternSet reducerPatterns; |
207 | }; |
208 | |
209 | } // namespace |
210 | |
211 | LogicalResult ReductionTreePass::initialize(MLIRContext *context) { |
212 | RewritePatternSet patterns(context); |
213 | ReductionPatternInterfaceCollection reducePatternCollection(context); |
214 | reducePatternCollection.populateReductionPatterns(pattern&: patterns); |
215 | reducerPatterns = std::move(patterns); |
216 | return success(); |
217 | } |
218 | |
219 | void ReductionTreePass::runOnOperation() { |
220 | Operation *topOperation = getOperation(); |
221 | while (topOperation->getParentOp() != nullptr) |
222 | topOperation = topOperation->getParentOp(); |
223 | ModuleOp module = dyn_cast<ModuleOp>(topOperation); |
224 | if (!module) { |
225 | emitError(getOperation()->getLoc()) |
226 | << "top-level op must be 'builtin.module'" ; |
227 | return signalPassFailure(); |
228 | } |
229 | |
230 | SmallVector<Operation *, 8> workList; |
231 | workList.push_back(Elt: getOperation()); |
232 | |
233 | do { |
234 | Operation *op = workList.pop_back_val(); |
235 | |
236 | for (Region ®ion : op->getRegions()) |
237 | if (!region.empty()) |
238 | if (failed(reduceOp(module: module, region))) |
239 | return signalPassFailure(); |
240 | |
241 | for (Region ®ion : op->getRegions()) |
242 | for (Operation &op : region.getOps()) |
243 | if (op.getNumRegions() != 0) |
244 | workList.push_back(Elt: &op); |
245 | } while (!workList.empty()); |
246 | } |
247 | |
248 | LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { |
249 | Tester test(testerName, testerArgs); |
250 | switch (traversalModeId) { |
251 | case TraversalMode::SinglePath: |
252 | return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( |
253 | module, region, reducerPatterns, test); |
254 | default: |
255 | return module.emitError() << "unsupported traversal mode detected" ; |
256 | } |
257 | } |
258 | |
259 | std::unique_ptr<Pass> mlir::createReductionTreePass() { |
260 | return std::make_unique<ReductionTreePass>(); |
261 | } |
262 | |