| 1 | //===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the Reduction Tree Pass class. It provides a framework for |
| 10 | // the implementation of different reduction passes in the MLIR Reduce tool. It |
| 11 | // allows for custom specification of the variant generation behavior. It |
| 12 | // implements methods that define the different possible traversals of the |
| 13 | // reduction tree. |
| 14 | // |
| 15 | //===----------------------------------------------------------------------===// |
| 16 | |
| 17 | #include "mlir/IR/DialectInterface.h" |
| 18 | #include "mlir/IR/OpDefinition.h" |
| 19 | #include "mlir/Reducer/Passes.h" |
| 20 | #include "mlir/Reducer/ReductionNode.h" |
| 21 | #include "mlir/Reducer/ReductionPatternInterface.h" |
| 22 | #include "mlir/Reducer/Tester.h" |
| 23 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 24 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 25 | |
| 26 | #include "llvm/ADT/ArrayRef.h" |
| 27 | #include "llvm/ADT/SmallVector.h" |
| 28 | #include "llvm/Support/Allocator.h" |
| 29 | #include "llvm/Support/ManagedStatic.h" |
| 30 | |
| 31 | namespace mlir { |
| 32 | #define GEN_PASS_DEF_REDUCTIONTREEPASS |
| 33 | #include "mlir/Reducer/Passes.h.inc" |
| 34 | } // namespace mlir |
| 35 | |
| 36 | using namespace mlir; |
| 37 | |
| 38 | /// We implicitly number each operation in the region and if an operation's |
| 39 | /// number falls into rangeToKeep, we need to keep it and apply the given |
| 40 | /// rewrite patterns on it. |
| 41 | static void applyPatterns(Region ®ion, |
| 42 | const FrozenRewritePatternSet &patterns, |
| 43 | ArrayRef<ReductionNode::Range> rangeToKeep, |
| 44 | bool eraseOpNotInRange) { |
| 45 | std::vector<Operation *> opsNotInRange; |
| 46 | std::vector<Operation *> opsInRange; |
| 47 | size_t keepIndex = 0; |
| 48 | for (const auto &op : enumerate(region.getOps())) { |
| 49 | int index = op.index(); |
| 50 | if (keepIndex < rangeToKeep.size() && |
| 51 | index == rangeToKeep[keepIndex].second) |
| 52 | ++keepIndex; |
| 53 | if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first) |
| 54 | opsNotInRange.push_back(&op.value()); |
| 55 | else |
| 56 | opsInRange.push_back(&op.value()); |
| 57 | } |
| 58 | |
| 59 | // `applyOpPatternsGreedily` with folding may erase the ops so we can't do the |
| 60 | // pattern matching in above iteration. Besides, erase op not-in-range may end |
| 61 | // up in invalid module, so `applyOpPatternsGreedily` with folding should come |
| 62 | // before that transform. |
| 63 | for (Operation *op : opsInRange) { |
| 64 | // `applyOpPatternsGreedily` with folding returns whether the op is |
| 65 | // converted. Omit it because we don't have expectation this reduction will |
| 66 | // be success or not. |
| 67 | (void)applyOpPatternsGreedily(op, patterns, |
| 68 | GreedyRewriteConfig().setStrictness( |
| 69 | GreedyRewriteStrictness::ExistingOps)); |
| 70 | } |
| 71 | |
| 72 | if (eraseOpNotInRange) |
| 73 | for (Operation *op : opsNotInRange) { |
| 74 | op->dropAllUses(); |
| 75 | op->erase(); |
| 76 | } |
| 77 | } |
| 78 | |
| 79 | /// We will apply the reducer patterns to the operations in the ranges specified |
| 80 | /// by ReductionNode. Note that we are not able to remove an operation without |
| 81 | /// replacing it with another valid operation. However, The validity of module |
| 82 | /// reduction is based on the Tester provided by the user and that means certain |
| 83 | /// invalid module is still interested by the use. Thus we provide an |
| 84 | /// alternative way to remove operations, which is using `eraseOpNotInRange` to |
| 85 | /// erase the operations not in the range specified by ReductionNode. |
| 86 | template <typename IteratorType> |
| 87 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
| 88 | const FrozenRewritePatternSet &patterns, |
| 89 | const Tester &test, bool eraseOpNotInRange) { |
| 90 | std::pair<Tester::Interestingness, size_t> initStatus = |
| 91 | test.isInteresting(module); |
| 92 | // While exploring the reduction tree, we always branch from an interesting |
| 93 | // node. Thus the root node must be interesting. |
| 94 | if (initStatus.first != Tester::Interestingness::True) |
| 95 | return module.emitWarning() << "uninterested module will not be reduced" ; |
| 96 | |
| 97 | llvm::SpecificBumpPtrAllocator<ReductionNode> allocator; |
| 98 | |
| 99 | std::vector<ReductionNode::Range> ranges{ |
| 100 | {0, std::distance(first: region.op_begin(), last: region.op_end())}}; |
| 101 | |
| 102 | ReductionNode *root = allocator.Allocate(); |
| 103 | new (root) ReductionNode(nullptr, ranges, allocator); |
| 104 | // Duplicate the module for root node and locate the region in the copy. |
| 105 | if (failed(root->initialize(parentModule: module, parentRegion&: region))) |
| 106 | llvm_unreachable("unexpected initialization failure" ); |
| 107 | root->update(result: initStatus); |
| 108 | |
| 109 | ReductionNode *smallestNode = root; |
| 110 | IteratorType iter(root); |
| 111 | |
| 112 | while (iter != IteratorType::end()) { |
| 113 | ReductionNode ¤tNode = *iter; |
| 114 | Region &curRegion = currentNode.getRegion(); |
| 115 | |
| 116 | applyPatterns(region&: curRegion, patterns, rangeToKeep: currentNode.getRanges(), |
| 117 | eraseOpNotInRange); |
| 118 | currentNode.update(result: test.isInteresting(currentNode.getModule())); |
| 119 | |
| 120 | if (currentNode.isInteresting() == Tester::Interestingness::True && |
| 121 | currentNode.getSize() < smallestNode->getSize()) |
| 122 | smallestNode = ¤tNode; |
| 123 | |
| 124 | ++iter; |
| 125 | } |
| 126 | |
| 127 | // At here, we have found an optimal path to reduce the given region. Retrieve |
| 128 | // the path and apply the reducer to it. |
| 129 | SmallVector<ReductionNode *> trace; |
| 130 | ReductionNode *curNode = smallestNode; |
| 131 | trace.push_back(Elt: curNode); |
| 132 | while (curNode != root) { |
| 133 | curNode = curNode->getParent(); |
| 134 | trace.push_back(Elt: curNode); |
| 135 | } |
| 136 | |
| 137 | // Reduce the region through the optimal path. |
| 138 | while (!trace.empty()) { |
| 139 | ReductionNode *top = trace.pop_back_val(); |
| 140 | applyPatterns(region, patterns, rangeToKeep: top->getStartRanges(), eraseOpNotInRange); |
| 141 | } |
| 142 | |
| 143 | if (test.isInteresting(module).first != Tester::Interestingness::True) |
| 144 | llvm::report_fatal_error(reason: "Reduced module is not interesting" ); |
| 145 | if (test.isInteresting(module).second != smallestNode->getSize()) |
| 146 | llvm::report_fatal_error( |
| 147 | reason: "Reduced module doesn't have consistent size with smallestNode" ); |
| 148 | return success(); |
| 149 | } |
| 150 | |
| 151 | template <typename IteratorType> |
| 152 | static LogicalResult findOptimal(ModuleOp module, Region ®ion, |
| 153 | const FrozenRewritePatternSet &patterns, |
| 154 | const Tester &test) { |
| 155 | // We separate the reduction process into 2 steps, the first one is to erase |
| 156 | // redundant operations and the second one is to apply the reducer patterns. |
| 157 | |
| 158 | // In the first phase, we don't apply any patterns so that we only select the |
| 159 | // range of operations to keep to the module stay interesting. |
| 160 | if (failed(findOptimal<IteratorType>(module, region, /*patterns=*/{}, test, |
| 161 | /*eraseOpNotInRange=*/true))) |
| 162 | return failure(); |
| 163 | // In the second phase, we suppose that no operation is redundant, so we try |
| 164 | // to rewrite the operation into simpler form. |
| 165 | return findOptimal<IteratorType>(module, region, patterns, test, |
| 166 | /*eraseOpNotInRange=*/false); |
| 167 | } |
| 168 | |
| 169 | namespace { |
| 170 | |
| 171 | //===----------------------------------------------------------------------===// |
| 172 | // Reduction Pattern Interface Collection |
| 173 | //===----------------------------------------------------------------------===// |
| 174 | |
| 175 | class ReductionPatternInterfaceCollection |
| 176 | : public DialectInterfaceCollection<DialectReductionPatternInterface> { |
| 177 | public: |
| 178 | using Base::Base; |
| 179 | |
| 180 | // Collect the reduce patterns defined by each dialect. |
| 181 | void populateReductionPatterns(RewritePatternSet &pattern) const { |
| 182 | for (const DialectReductionPatternInterface &interface : *this) |
| 183 | interface.populateReductionPatterns(patterns&: pattern); |
| 184 | } |
| 185 | }; |
| 186 | |
| 187 | //===----------------------------------------------------------------------===// |
| 188 | // ReductionTreePass |
| 189 | //===----------------------------------------------------------------------===// |
| 190 | |
| 191 | /// This class defines the Reduction Tree Pass. It provides a framework to |
| 192 | /// to implement a reduction pass using a tree structure to keep track of the |
| 193 | /// generated reduced variants. |
| 194 | class ReductionTreePass |
| 195 | : public impl::ReductionTreePassBase<ReductionTreePass> { |
| 196 | public: |
| 197 | using Base::Base; |
| 198 | |
| 199 | LogicalResult initialize(MLIRContext *context) override; |
| 200 | |
| 201 | /// Runs the pass instance in the pass pipeline. |
| 202 | void runOnOperation() override; |
| 203 | |
| 204 | private: |
| 205 | LogicalResult reduceOp(ModuleOp module, Region ®ion); |
| 206 | |
| 207 | FrozenRewritePatternSet reducerPatterns; |
| 208 | }; |
| 209 | |
| 210 | } // namespace |
| 211 | |
| 212 | LogicalResult ReductionTreePass::initialize(MLIRContext *context) { |
| 213 | RewritePatternSet patterns(context); |
| 214 | ReductionPatternInterfaceCollection reducePatternCollection(context); |
| 215 | reducePatternCollection.populateReductionPatterns(pattern&: patterns); |
| 216 | reducerPatterns = std::move(patterns); |
| 217 | return success(); |
| 218 | } |
| 219 | |
| 220 | void ReductionTreePass::runOnOperation() { |
| 221 | Operation *topOperation = getOperation(); |
| 222 | while (topOperation->getParentOp() != nullptr) |
| 223 | topOperation = topOperation->getParentOp(); |
| 224 | ModuleOp module = dyn_cast<ModuleOp>(topOperation); |
| 225 | if (!module) { |
| 226 | emitError(getOperation()->getLoc()) |
| 227 | << "top-level op must be 'builtin.module'" ; |
| 228 | return signalPassFailure(); |
| 229 | } |
| 230 | |
| 231 | SmallVector<Operation *, 8> workList; |
| 232 | workList.push_back(Elt: getOperation()); |
| 233 | |
| 234 | do { |
| 235 | Operation *op = workList.pop_back_val(); |
| 236 | |
| 237 | for (Region ®ion : op->getRegions()) |
| 238 | if (!region.empty()) |
| 239 | if (failed(reduceOp(module: module, region))) |
| 240 | return signalPassFailure(); |
| 241 | |
| 242 | for (Region ®ion : op->getRegions()) |
| 243 | for (Operation &op : region.getOps()) |
| 244 | if (op.getNumRegions() != 0) |
| 245 | workList.push_back(Elt: &op); |
| 246 | } while (!workList.empty()); |
| 247 | } |
| 248 | |
| 249 | LogicalResult ReductionTreePass::reduceOp(ModuleOp module, Region ®ion) { |
| 250 | Tester test(testerName, testerArgs); |
| 251 | switch (traversalModeId) { |
| 252 | case TraversalMode::SinglePath: |
| 253 | return findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>( |
| 254 | module, region, reducerPatterns, test); |
| 255 | default: |
| 256 | return module.emitError() << "unsupported traversal mode detected" ; |
| 257 | } |
| 258 | } |
| 259 | |