| 1 | //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
| 10 | |
| 11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| 12 | #include "mlir/Dialect/SCF/IR/SCF.h" |
| 13 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
| 14 | #include "mlir/Transforms/RegionUtils.h" |
| 15 | #include "llvm/ADT/SmallSet.h" |
| 16 | #include "llvm/Support/CommandLine.h" |
| 17 | #include "llvm/Support/Debug.h" |
| 18 | |
| 19 | namespace mlir { |
| 20 | #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING |
| 21 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
| 22 | } // namespace mlir |
| 23 | |
| 24 | #define DEBUG_TYPE "parallel-loop-collapsing" |
| 25 | |
| 26 | using namespace mlir; |
| 27 | |
| 28 | namespace { |
| 29 | struct TestSCFParallelLoopCollapsing |
| 30 | : public impl::TestSCFParallelLoopCollapsingBase< |
| 31 | TestSCFParallelLoopCollapsing> { |
| 32 | |
| 33 | void runOnOperation() override { |
| 34 | Operation *module = getOperation(); |
| 35 | |
| 36 | // The common case for GPU dialect will be simplifying the ParallelOp to 3 |
| 37 | // arguments, so we do that here to simplify things. |
| 38 | llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops; |
| 39 | |
| 40 | // Gather the input args into the format required by |
| 41 | // `collapseParallelLoops`. |
| 42 | if (!clCollapsedIndices0.empty()) |
| 43 | combinedLoops.push_back(clCollapsedIndices0); |
| 44 | if (!clCollapsedIndices1.empty()) { |
| 45 | if (clCollapsedIndices0.empty()) { |
| 46 | llvm::errs() |
| 47 | << "collapsed-indices-1 specified but not collapsed-indices-0" ; |
| 48 | signalPassFailure(); |
| 49 | return; |
| 50 | } |
| 51 | combinedLoops.push_back(clCollapsedIndices1); |
| 52 | } |
| 53 | if (!clCollapsedIndices2.empty()) { |
| 54 | if (clCollapsedIndices1.empty()) { |
| 55 | llvm::errs() |
| 56 | << "collapsed-indices-2 specified but not collapsed-indices-1" ; |
| 57 | signalPassFailure(); |
| 58 | return; |
| 59 | } |
| 60 | combinedLoops.push_back(clCollapsedIndices2); |
| 61 | } |
| 62 | |
| 63 | if (combinedLoops.empty()) { |
| 64 | llvm::errs() << "No collapsed-indices were specified. This pass is only " |
| 65 | "for testing and does not automatically collapse all " |
| 66 | "parallel loops or similar." ; |
| 67 | signalPassFailure(); |
| 68 | return; |
| 69 | } |
| 70 | |
| 71 | // Confirm that the specified loops are [0,N) by testing that N values exist |
| 72 | // with the maximum value being N-1. |
| 73 | llvm::SmallSet<unsigned, 8> flattenedCombinedLoops; |
| 74 | unsigned maxCollapsedIndex = 0; |
| 75 | for (auto &loops : combinedLoops) { |
| 76 | for (auto &loop : loops) { |
| 77 | flattenedCombinedLoops.insert(loop); |
| 78 | maxCollapsedIndex = std::max(maxCollapsedIndex, loop); |
| 79 | } |
| 80 | } |
| 81 | |
| 82 | if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 || |
| 83 | !flattenedCombinedLoops.contains(maxCollapsedIndex)) { |
| 84 | llvm::errs() |
| 85 | << "collapsed-indices arguments must include all values [0,N)." ; |
| 86 | signalPassFailure(); |
| 87 | return; |
| 88 | } |
| 89 | |
| 90 | // Only apply the transformation on parallel loops where the specified |
| 91 | // transformation is valid, but do NOT early abort in the case of invalid |
| 92 | // loops. |
| 93 | IRRewriter rewriter(&getContext()); |
| 94 | module->walk([&](scf::ParallelOp op) { |
| 95 | if (flattenedCombinedLoops.size() != op.getNumLoops()) { |
| 96 | op.emitOpError("has " ) |
| 97 | << op.getNumLoops() |
| 98 | << " iter args while this limited functionality testing pass was " |
| 99 | "configured only for loops with exactly " |
| 100 | << flattenedCombinedLoops.size() << " iter args." ; |
| 101 | return; |
| 102 | } |
| 103 | collapseParallelLoops(rewriter, op, combinedLoops); |
| 104 | }); |
| 105 | } |
| 106 | }; |
| 107 | } // namespace |
| 108 | |
| 109 | std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() { |
| 110 | return std::make_unique<TestSCFParallelLoopCollapsing>(); |
| 111 | } |
| 112 | |