1//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the Reduction Tree Pass class. It provides a framework for
10// the implementation of different reduction passes in the MLIR Reduce tool. It
11// allows for custom specification of the variant generation behavior. It
12// implements methods that define the different possible traversals of the
13// reduction tree.
14//
15//===----------------------------------------------------------------------===//
16
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "mlir/Reducer/Passes.h"
20#include "mlir/Reducer/ReductionNode.h"
21#include "mlir/Reducer/ReductionPatternInterface.h"
22#include "mlir/Reducer/Tester.h"
23#include "mlir/Rewrite/FrozenRewritePatternSet.h"
24#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Allocator.h"
29#include "llvm/Support/ManagedStatic.h"
30
31namespace mlir {
32#define GEN_PASS_DEF_REDUCTIONTREEPASS
33#include "mlir/Reducer/Passes.h.inc"
34} // namespace mlir
35
36using namespace mlir;
37
38/// We implicitly number each operation in the region and if an operation's
39/// number falls into rangeToKeep, we need to keep it and apply the given
40/// rewrite patterns on it.
41static void applyPatterns(Region &region,
42 const FrozenRewritePatternSet &patterns,
43 ArrayRef<ReductionNode::Range> rangeToKeep,
44 bool eraseOpNotInRange) {
45 std::vector<Operation *> opsNotInRange;
46 std::vector<Operation *> opsInRange;
47 size_t keepIndex = 0;
48 for (const auto &op : enumerate(region.getOps())) {
49 int index = op.index();
50 if (keepIndex < rangeToKeep.size() &&
51 index == rangeToKeep[keepIndex].second)
52 ++keepIndex;
53 if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
54 opsNotInRange.push_back(&op.value());
55 else
56 opsInRange.push_back(&op.value());
57 }
58
59 // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the
60 // pattern matching in above iteration. Besides, erase op not-in-range may end
61 // up in invalid module, so `applyOpPatternsGreedily` with folding should come
62 // before that transform.
63 for (Operation *op : opsInRange) {
64 // `applyOpPatternsGreedily` with folding returns whether the op is
65 // converted. Omit it because we don't have expectation this reduction will
66 // be success or not.
67 (void)applyOpPatternsGreedily(op, patterns,
68 GreedyRewriteConfig().setStrictness(
69 GreedyRewriteStrictness::ExistingOps));
70 }
71
72 if (eraseOpNotInRange)
73 for (Operation *op : opsNotInRange) {
74 op->dropAllUses();
75 op->erase();
76 }
77}
78
79/// We will apply the reducer patterns to the operations in the ranges specified
80/// by ReductionNode. Note that we are not able to remove an operation without
81/// replacing it with another valid operation. However, The validity of module
82/// reduction is based on the Tester provided by the user and that means certain
83/// invalid module is still interested by the use. Thus we provide an
84/// alternative way to remove operations, which is using `eraseOpNotInRange` to
85/// erase the operations not in the range specified by ReductionNode.
86template <typename IteratorType>
87static LogicalResult findOptimal(ModuleOp module, Region &region,
88 const FrozenRewritePatternSet &patterns,
89 const Tester &test, bool eraseOpNotInRange) {
90 std::pair<Tester::Interestingness, size_t> initStatus =
91 test.isInteresting(module);
92 // While exploring the reduction tree, we always branch from an interesting
93 // node. Thus the root node must be interesting.
94 if (initStatus.first != Tester::Interestingness::True)
95 return module.emitWarning() << "uninterested module will not be reduced";
96
97 llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
98
99 std::vector<ReductionNode::Range> ranges{
100 {0, std::distance(first: region.op_begin(), last: region.op_end())}};
101
102 ReductionNode *root = allocator.Allocate();
103 new (root) ReductionNode(nullptr, ranges, allocator);
104 // Duplicate the module for root node and locate the region in the copy.
105 if (failed(root->initialize(parentModule: module, parentRegion&: region)))
106 llvm_unreachable("unexpected initialization failure");
107 root->update(result: initStatus);
108
109 ReductionNode *smallestNode = root;
110 IteratorType iter(root);
111
112 while (iter != IteratorType::end()) {
113 ReductionNode &currentNode = *iter;
114 Region &curRegion = currentNode.getRegion();
115
116 applyPatterns(region&: curRegion, patterns, rangeToKeep: currentNode.getRanges(),
117 eraseOpNotInRange);
118 currentNode.update(result: test.isInteresting(currentNode.getModule()));
119
120 if (currentNode.isInteresting() == Tester::Interestingness::True &&
121 currentNode.getSize() < smallestNode->getSize())
122 smallestNode = &currentNode;
123
124 ++iter;
125 }
126
127 // At here, we have found an optimal path to reduce the given region. Retrieve
128 // the path and apply the reducer to it.
129 SmallVector<ReductionNode *> trace;
130 ReductionNode *curNode = smallestNode;
131 trace.push_back(Elt: curNode);
132 while (curNode != root) {
133 curNode = curNode->getParent();
134 trace.push_back(Elt: curNode);
135 }
136
137 // Reduce the region through the optimal path.
138 while (!trace.empty()) {
139 ReductionNode *top = trace.pop_back_val();
140 applyPatterns(region, patterns, rangeToKeep: top->getStartRanges(), eraseOpNotInRange);
141 }
142
143 if (test.isInteresting(module).first != Tester::Interestingness::True)
144 llvm::report_fatal_error(reason: "Reduced module is not interesting");
145 if (test.isInteresting(module).second != smallestNode->getSize())
146 llvm::report_fatal_error(
147 reason: "Reduced module doesn't have consistent size with smallestNode");
148 return success();
149}
150
151template <typename IteratorType>
152static LogicalResult findOptimal(ModuleOp module, Region &region,
153 const FrozenRewritePatternSet &patterns,
154 const Tester &test) {
155 // We separate the reduction process into 2 steps, the first one is to erase
156 // redundant operations and the second one is to apply the reducer patterns.
157
158 // In the first phase, we don't apply any patterns so that we only select the
159 // range of operations to keep to the module stay interesting.
160 if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
161 /*eraseOpNotInRange=*/true)))
162 return failure();
163 // In the second phase, we suppose that no operation is redundant, so we try
164 // to rewrite the operation into simpler form.
165 return findOptimal<IteratorType>(module, region, patterns, test,
166 /*eraseOpNotInRange=*/false);
167}
168
169namespace {
170
171//===----------------------------------------------------------------------===//
172// Reduction Pattern Interface Collection
173//===----------------------------------------------------------------------===//
174
175class ReductionPatternInterfaceCollection
176 : public DialectInterfaceCollection<DialectReductionPatternInterface> {
177public:
178 using Base::Base;
179
180 // Collect the reduce patterns defined by each dialect.
181 void populateReductionPatterns(RewritePatternSet &pattern) const {
182 for (const DialectReductionPatternInterface &interface : *this)
183 interface.populateReductionPatterns(patterns&: pattern);
184 }
185};
186
187//===----------------------------------------------------------------------===//
188// ReductionTreePass
189//===----------------------------------------------------------------------===//
190
191/// This class defines the Reduction Tree Pass. It provides a framework to
192/// to implement a reduction pass using a tree structure to keep track of the
193/// generated reduced variants.
194class ReductionTreePass
195 : public impl::ReductionTreePassBase<ReductionTreePass> {
196public:
197 using Base::Base;
198
199 LogicalResult initialize(MLIRContext *context) override;
200
201 /// Runs the pass instance in the pass pipeline.
202 void runOnOperation() override;
203
204private:
205 LogicalResult reduceOp(ModuleOp module, Region &region);
206
207 FrozenRewritePatternSet reducerPatterns;
208};
209
210} // namespace
211
212LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
213 RewritePatternSet patterns(context);
214 ReductionPatternInterfaceCollection reducePatternCollection(context);
215 reducePatternCollection.populateReductionPatterns(pattern&: patterns);
216 reducerPatterns = std::move(patterns);
217 return success();
218}
219
220void ReductionTreePass::runOnOperation() {
221 Operation *topOperation = getOperation();
222 while (topOperation->getParentOp() != nullptr)
223 topOperation = topOperation->getParentOp();
224 ModuleOp module = dyn_cast<ModuleOp>(topOperation);
225 if (!module) {
226 emitError(getOperation()->getLoc())
227 << "top-level op must be 'builtin.module'";
228 return signalPassFailure();
229 }
230
231 SmallVector<Operation *, 8> workList;
232 workList.push_back(Elt: getOperation());
233
234 do {
235 Operation *op = workList.pop_back_val();
236
237 for (Region &region : op->getRegions())
238 if (!region.empty())
239 if (failed(reduceOp(module: module, region)))
240 return signalPassFailure();
241
242 for (Region &region : op->getRegions())
243 for (Operation &op : region.getOps())
244 if (op.getNumRegions() != 0)
245 workList.push_back(Elt: &op);
246 } while (!workList.empty());
247}
248
249LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
250 Tester test(testerName, testerArgs);
251 switch (traversalModeId) {
252 case TraversalMode::SinglePath:
253 return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
254 module, region, reducerPatterns, test);
255 default:
256 return module.emitError() << "unsupported traversal mode detected";
257 }
258}
259

Provided by KDAB

Privacy Policy
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more

source code of mlir/lib/Reducer/ReductionTreePass.cpp