1 | //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the Reduction Tree Pass class. It provides a framework for |
10 | // the implementation of different reduction passes in the MLIR Reduce tool. It |
11 | // allows for custom specification of the variant generation behavior. It |
12 | // implements methods that define the different possible traversals of the |
13 | // reduction tree. |
14 | // |
15 | //===----------------------------------------------------------------------===// |
16 | |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "mlir/Reducer/Passes.h" |
20 | #include "mlir/Reducer/ReductionNode.h" |
21 | #include "mlir/Reducer/ReductionPatternInterface.h" |
22 | #include "mlir/Reducer/Tester.h" |
23 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
25 | |
26 | #include "llvm/ADT/ArrayRef.h" |
27 | #include "llvm/ADT/SmallVector.h" |
28 | #include "llvm/Support/Allocator.h" |
29 | #include "llvm/Support/ManagedStatic.h" |
30 | |
31 | namespace mlir { |
32 | #define GEN_PASS_DEF_REDUCTIONTREEPASS |
33 | #include "mlir/Reducer/Passes.h.inc" |
34 | } // namespace mlir |
35 | |
36 | using namespace mlir; |
37 | |
38 | /// We implicitly number each operation in the region and if an operation's |
39 | /// number falls into rangeToKeep, we need to keep it and apply the given |
40 | /// rewrite patterns on it. |
41 | static void applyPatterns(Region ®ion, |
42 | const FrozenRewritePatternSet &patterns, |
43 | ArrayRef<ReductionNode::Range> rangeToKeep, |
44 | bool eraseOpNotInRange) { |
45 | std::vector<Operation *> opsNotInRange; |
46 | std::vector<Operation *> opsInRange; |
47 | size_t keepIndex = 0; |
48 | for (const auto &op : enumerate(region.getOps())) { |
49 | int index = op.index(); |
50 | if (keepIndex < rangeToKeep.size() && |
51 | index == rangeToKeep[keepIndex].second) |
52 | ++keepIndex; |
53 | if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) |
54 | opsNotInRange.push_back(&op.value()); |
55 | else |
56 | opsInRange.push_back(&op.value()); |
57 | } |
58 | |
59 | // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the |
60 | // pattern matching in above iteration. Besides, erase op not-in-range may end |
61 | // up in invalid module, so `applyOpPatternsGreedily` with folding should come |
62 | // before that transform. |
63 | for (Operation *op : opsInRange) { |
64 | // `applyOpPatternsGreedily` with folding returns whether the op is |
65 | // converted. Omit it because we don't have expectation this reduction will |
66 | // be success or not. |
67 | (void)applyOpPatternsGreedily(op, patterns, |
68 | GreedyRewriteConfig().setStrictness( |
69 | GreedyRewriteStrictness::ExistingOps)); |
70 | } |
71 | |
72 | if (eraseOpNotInRange) |
73 | for (Operation *op : opsNotInRange) { |
74 | op->dropAllUses(); |
75 | op->erase(); |
76 | } |
77 | } |
78 | |
79 | /// We will apply the reducer patterns to the operations in the ranges specified |
80 | /// by ReductionNode. Note that we are not able to remove an operation without |
81 | /// replacing it with another valid operation. However, The validity of module |
82 | /// reduction is based on the Tester provided by the user and that means certain |
83 | /// invalid module is still interested by the use. Thus we provide an |
84 | /// alternative way to remove operations, which is using `eraseOpNotInRange` to |
85 | /// erase the operations not in the range specified by ReductionNode. |
86 | template <typename IteratorType> |
87 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
88 | const FrozenRewritePatternSet &patterns, |
89 | const Tester &test, bool eraseOpNotInRange) { |
90 | std::pair<Tester::Interestingness, size_t> initStatus = |
91 | test.isInteresting(module); |
92 | // While exploring the reduction tree, we always branch from an interesting |
93 | // node. Thus the root node must be interesting. |
94 | if (initStatus.first != Tester::Interestingness::True) |
95 | return module.emitWarning() << "uninterested module will not be reduced" ; |
96 | |
97 | llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; |
98 | |
99 | std::vector<ReductionNode::Range> ranges{ |
100 | {0, std::distance(first: region.op_begin(), last: region.op_end())}}; |
101 | |
102 | ReductionNode *root = allocator.Allocate(); |
103 | new (root) ReductionNode(nullptr, ranges, allocator); |
104 | // Duplicate the module for root node and locate the region in the copy. |
105 | if (failed(root->initialize(parentModule: module, parentRegion&: region))) |
106 | llvm_unreachable("unexpected initialization failure" ); |
107 | root->update(result: initStatus); |
108 | |
109 | ReductionNode *smallestNode = root; |
110 | IteratorType iter(root); |
111 | |
112 | while (iter != IteratorType::end()) { |
113 | ReductionNode ¤tNode = *iter; |
114 | Region &curRegion = currentNode.getRegion(); |
115 | |
116 | applyPatterns(region&: curRegion, patterns, rangeToKeep: currentNode.getRanges(), |
117 | eraseOpNotInRange); |
118 | currentNode.update(result: test.isInteresting(currentNode.getModule())); |
119 | |
120 | if (currentNode.isInteresting() == Tester::Interestingness::True && |
121 | currentNode.getSize() < smallestNode->getSize()) |
122 | smallestNode = ¤tNode; |
123 | |
124 | ++iter; |
125 | } |
126 | |
127 | // At here, we have found an optimal path to reduce the given region. Retrieve |
128 | // the path and apply the reducer to it. |
129 | SmallVector<ReductionNode *> trace; |
130 | ReductionNode *curNode = smallestNode; |
131 | trace.push_back(Elt: curNode); |
132 | while (curNode != root) { |
133 | curNode = curNode->getParent(); |
134 | trace.push_back(Elt: curNode); |
135 | } |
136 | |
137 | // Reduce the region through the optimal path. |
138 | while (!trace.empty()) { |
139 | ReductionNode *top = trace.pop_back_val(); |
140 | applyPatterns(region, patterns, rangeToKeep: top->getStartRanges(), eraseOpNotInRange); |
141 | } |
142 | |
143 | if (test.isInteresting(module).first != Tester::Interestingness::True) |
144 | llvm::report_fatal_error(reason: "Reduced module is not interesting" ); |
145 | if (test.isInteresting(module).second != smallestNode->getSize()) |
146 | llvm::report_fatal_error( |
147 | reason: "Reduced module doesn't have consistent size with smallestNode" ); |
148 | return success(); |
149 | } |
150 | |
151 | template <typename IteratorType> |
152 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
153 | const FrozenRewritePatternSet &patterns, |
154 | const Tester &test) { |
155 | // We separate the reduction process into 2 steps, the first one is to erase |
156 | // redundant operations and the second one is to apply the reducer patterns. |
157 | |
158 | // In the first phase, we don't apply any patterns so that we only select the |
159 | // range of operations to keep to the module stay interesting. |
160 | if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, |
161 | /*eraseOpNotInRange=*/true))) |
162 | return failure(); |
163 | // In the second phase, we suppose that no operation is redundant, so we try |
164 | // to rewrite the operation into simpler form. |
165 | return findOptimal<IteratorType>(module, region, patterns, test, |
166 | /*eraseOpNotInRange=*/false); |
167 | } |
168 | |
169 | namespace { |
170 | |
171 | //===----------------------------------------------------------------------===// |
172 | // Reduction Pattern Interface Collection |
173 | //===----------------------------------------------------------------------===// |
174 | |
175 | class ReductionPatternInterfaceCollection |
176 | : public DialectInterfaceCollection<DialectReductionPatternInterface> { |
177 | public: |
178 | using Base::Base; |
179 | |
180 | // Collect the reduce patterns defined by each dialect. |
181 | void populateReductionPatterns(RewritePatternSet &pattern) const { |
182 | for (const DialectReductionPatternInterface &interface : *this) |
183 | interface.populateReductionPatterns(patterns&: pattern); |
184 | } |
185 | }; |
186 | |
187 | //===----------------------------------------------------------------------===// |
188 | // ReductionTreePass |
189 | //===----------------------------------------------------------------------===// |
190 | |
191 | /// This class defines the Reduction Tree Pass. It provides a framework to |
192 | /// to implement a reduction pass using a tree structure to keep track of the |
193 | /// generated reduced variants. |
194 | class ReductionTreePass |
195 | : public impl::ReductionTreePassBase<ReductionTreePass> { |
196 | public: |
197 | using Base::Base; |
198 | |
199 | LogicalResult initialize(MLIRContext *context) override; |
200 | |
201 | /// Runs the pass instance in the pass pipeline. |
202 | void runOnOperation() override; |
203 | |
204 | private: |
205 | LogicalResult reduceOp(ModuleOp module, Region ®ion); |
206 | |
207 | FrozenRewritePatternSet reducerPatterns; |
208 | }; |
209 | |
210 | } // namespace |
211 | |
212 | LogicalResult ReductionTreePass::initialize(MLIRContext *context) { |
213 | RewritePatternSet patterns(context); |
214 | ReductionPatternInterfaceCollection reducePatternCollection(context); |
215 | reducePatternCollection.populateReductionPatterns(pattern&: patterns); |
216 | reducerPatterns = std::move(patterns); |
217 | return success(); |
218 | } |
219 | |
220 | void ReductionTreePass::runOnOperation() { |
221 | Operation *topOperation = getOperation(); |
222 | while (topOperation->getParentOp() != nullptr) |
223 | topOperation = topOperation->getParentOp(); |
224 | ModuleOp module = dyn_cast<ModuleOp>(topOperation); |
225 | if (!module) { |
226 | emitError(getOperation()->getLoc()) |
227 | << "top-level op must be 'builtin.module'" ; |
228 | return signalPassFailure(); |
229 | } |
230 | |
231 | SmallVector<Operation *, 8> workList; |
232 | workList.push_back(Elt: getOperation()); |
233 | |
234 | do { |
235 | Operation *op = workList.pop_back_val(); |
236 | |
237 | for (Region ®ion : op->getRegions()) |
238 | if (!region.empty()) |
239 | if (failed(reduceOp(module: module, region))) |
240 | return signalPassFailure(); |
241 | |
242 | for (Region ®ion : op->getRegions()) |
243 | for (Operation &op : region.getOps()) |
244 | if (op.getNumRegions() != 0) |
245 | workList.push_back(Elt: &op); |
246 | } while (!workList.empty()); |
247 | } |
248 | |
249 | LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { |
250 | Tester test(testerName, testerArgs); |
251 | switch (traversalModeId) { |
252 | case TraversalMode::SinglePath: |
253 | return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( |
254 | module, region, reducerPatterns, test); |
255 | default: |
256 | return module.emitError() << "unsupported traversal mode detected" ; |
257 | } |
258 | } |
259 | |