1//===- ParallelLoopCollapsing.cpp - Pass collapsing parallel loop indices -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SCF/Transforms/Passes.h"
10
11#include "mlir/Dialect/Affine/IR/AffineOps.h"
12#include "mlir/Dialect/SCF/IR/SCF.h"
13#include "mlir/Dialect/SCF/Utils/Utils.h"
14#include "mlir/Transforms/RegionUtils.h"
15#include "llvm/ADT/SmallSet.h"
16#include "llvm/Support/CommandLine.h"
17#include "llvm/Support/Debug.h"
18
19namespace mlir {
20#define GEN_PASS_DEF_TESTSCFPARALLELLOOPCOLLAPSING
21#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
22} // namespace mlir
23
24#define DEBUG_TYPE "parallel-loop-collapsing"
25
26using namespace mlir;
27
28namespace {
29struct TestSCFParallelLoopCollapsing
30 : public impl::TestSCFParallelLoopCollapsingBase<
31 TestSCFParallelLoopCollapsing> {
32
33 void runOnOperation() override {
34 Operation *module = getOperation();
35
36 // The common case for GPU dialect will be simplifying the ParallelOp to 3
37 // arguments, so we do that here to simplify things.
38 llvm::SmallVector<std::vector<unsigned>, 3> combinedLoops;
39
40 // Gather the input args into the format required by
41 // `collapseParallelLoops`.
42 if (!clCollapsedIndices0.empty())
43 combinedLoops.push_back(clCollapsedIndices0);
44 if (!clCollapsedIndices1.empty()) {
45 if (clCollapsedIndices0.empty()) {
46 llvm::errs()
47 << "collapsed-indices-1 specified but not collapsed-indices-0";
48 signalPassFailure();
49 return;
50 }
51 combinedLoops.push_back(clCollapsedIndices1);
52 }
53 if (!clCollapsedIndices2.empty()) {
54 if (clCollapsedIndices1.empty()) {
55 llvm::errs()
56 << "collapsed-indices-2 specified but not collapsed-indices-1";
57 signalPassFailure();
58 return;
59 }
60 combinedLoops.push_back(clCollapsedIndices2);
61 }
62
63 if (combinedLoops.empty()) {
64 llvm::errs() << "No collapsed-indices were specified. This pass is only "
65 "for testing and does not automatically collapse all "
66 "parallel loops or similar.";
67 signalPassFailure();
68 return;
69 }
70
71 // Confirm that the specified loops are [0,N) by testing that N values exist
72 // with the maximum value being N-1.
73 llvm::SmallSet<unsigned, 8> flattenedCombinedLoops;
74 unsigned maxCollapsedIndex = 0;
75 for (auto &loops : combinedLoops) {
76 for (auto &loop : loops) {
77 flattenedCombinedLoops.insert(loop);
78 maxCollapsedIndex = std::max(maxCollapsedIndex, loop);
79 }
80 }
81
82 if (maxCollapsedIndex != flattenedCombinedLoops.size() - 1 ||
83 !flattenedCombinedLoops.contains(maxCollapsedIndex)) {
84 llvm::errs()
85 << "collapsed-indices arguments must include all values [0,N).";
86 signalPassFailure();
87 return;
88 }
89
90 // Only apply the transformation on parallel loops where the specified
91 // transformation is valid, but do NOT early abort in the case of invalid
92 // loops.
93 IRRewriter rewriter(&getContext());
94 module->walk([&](scf::ParallelOp op) {
95 if (flattenedCombinedLoops.size() != op.getNumLoops()) {
96 op.emitOpError("has ")
97 << op.getNumLoops()
98 << " iter args while this limited functionality testing pass was "
99 "configured only for loops with exactly "
100 << flattenedCombinedLoops.size() << " iter args.";
101 return;
102 }
103 collapseParallelLoops(rewriter, op, combinedLoops);
104 });
105 }
106};
107} // namespace
108
109std::unique_ptr<Pass> mlir::createTestSCFParallelLoopCollapsingPass() {
110 return std::make_unique<TestSCFParallelLoopCollapsing>();
111}
112

Provided by KDAB

Privacy Policy
Learn to use CMake with our Intro Training
Find out more

source code of mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp