1 | //===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SCF/Transforms/Passes.h" |
10 | |
11 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
12 | #include "mlir/Dialect/SCF/IR/SCF.h" |
13 | #include "mlir/Dialect/SCF/Utils/Utils.h" |
14 | #include "mlir/Transforms/RegionUtils.h" |
15 | #include "llvm/ADT/SmallSet.h" |
16 | #include "llvm/Support/CommandLine.h" |
17 | #include "llvm/Support/Debug.h" |
18 | |
19 | namespace mlir { |
20 | #define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING |
21 | #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" |
22 | } // namespace mlir |
23 | |
24 | #define DEBUG_TYPE "parallel-loop-collapsing" |
25 | |
26 | using namespace mlir; |
27 | |
28 | namespace { |
29 | struct TestSCFParallelLoopCollapsing |
30 | : public impl::TestSCFParallelLoopCollapsingBase< |
31 | TestSCFParallelLoopCollapsing> { |
32 | |
33 | void runOnOperation() override { |
34 | Operation *module = getOperation(); |
35 | |
36 | // The common case for GPU dialect will be simplifying the ParallelOp to 3 |
37 | // arguments, so we do that here to simplify things. |
38 | llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops; |
39 | |
40 | // Gather the input args into the format required by |
41 | // `collapseParallelLoops`. |
42 | if (!clCollapsedIndices0.empty()) |
43 | combinedLoops.push_back(clCollapsedIndices0); |
44 | if (!clCollapsedIndices1.empty()) { |
45 | if (clCollapsedIndices0.empty()) { |
46 | llvm::errs() |
47 | << "collapsed-indices-1 specified but not collapsed-indices-0" ; |
48 | signalPassFailure(); |
49 | return; |
50 | } |
51 | combinedLoops.push_back(clCollapsedIndices1); |
52 | } |
53 | if (!clCollapsedIndices2.empty()) { |
54 | if (clCollapsedIndices1.empty()) { |
55 | llvm::errs() |
56 | << "collapsed-indices-2 specified but not collapsed-indices-1" ; |
57 | signalPassFailure(); |
58 | return; |
59 | } |
60 | combinedLoops.push_back(clCollapsedIndices2); |
61 | } |
62 | |
63 | if (combinedLoops.empty()) { |
64 | llvm::errs() << "No collapsed-indices were specified. This pass is only " |
65 | "for testing and does not automatically collapse all " |
66 | "parallel loops or similar." ; |
67 | signalPassFailure(); |
68 | return; |
69 | } |
70 | |
71 | // Confirm that the specified loops are [0,N) by testing that N values exist |
72 | // with the maximum value being N-1. |
73 | llvm::SmallSet<unsigned, 8> flattenedCombinedLoops; |
74 | unsigned maxCollapsedIndex = 0; |
75 | for (auto &loops : combinedLoops) { |
76 | for (auto &loop : loops) { |
77 | flattenedCombinedLoops.insert(loop); |
78 | maxCollapsedIndex = std::max(maxCollapsedIndex, loop); |
79 | } |
80 | } |
81 | |
82 | if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 || |
83 | !flattenedCombinedLoops.contains(maxCollapsedIndex)) { |
84 | llvm::errs() |
85 | << "collapsed-indices arguments must include all values [0,N)." ; |
86 | signalPassFailure(); |
87 | return; |
88 | } |
89 | |
90 | // Only apply the transformation on parallel loops where the specified |
91 | // transformation is valid, but do NOT early abort in the case of invalid |
92 | // loops. |
93 | IRRewriter rewriter(&getContext()); |
94 | module->walk([&](scf::ParallelOp op) { |
95 | if (flattenedCombinedLoops.size() != op.getNumLoops()) { |
96 | op.emitOpError("has " ) |
97 | << op.getNumLoops() |
98 | << " iter args while this limited functionality testing pass was " |
99 | "configured only for loops with exactly " |
100 | << flattenedCombinedLoops.size() << " iter args." ; |
101 | return; |
102 | } |
103 | collapseParallelLoops(rewriter, op, combinedLoops); |
104 | }); |
105 | } |
106 | }; |
107 | } // namespace |
108 | |
109 | std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() { |
110 | return std::make_unique<TestSCFParallelLoopCollapsing>(); |
111 | } |
112 | |