1 | //===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/IR/Builders.h" |
10 | #include "mlir/IR/BuiltinOps.h" |
11 | #include "mlir/Pass/Pass.h" |
12 | #include "mlir/Transforms/TopologicalSortUtils.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | namespace { |
17 | struct TestTopologicalSortAnalysisPass |
18 | : public PassWrapper<TestTopologicalSortAnalysisPass, |
19 | OperationPass<ModuleOp>> { |
20 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass) |
21 | |
22 | StringRef getArgument() const final { |
23 | return "test-topological-sort-analysis"; |
24 | } |
25 | StringRef getDescription() const final { |
26 | return "Test topological sorting of ops"; |
27 | } |
28 | |
29 | void runOnOperation() override { |
30 | Operation *op = getOperation(); |
31 | OpBuilder builder(op->getContext()); |
32 | |
33 | WalkResult result = op->walk(callback: [&](Operation *root) { |
34 | if (!root->hasAttr(name: "root")) |
35 | return WalkResult::advance(); |
36 | |
37 | SmallVector<Operation *> selectedOps; |
38 | root->walk(callback: [&](Operation *selected) { |
39 | if (!selected->hasAttr(name: "selected")) |
40 | return WalkResult::advance(); |
41 | if (root->hasAttr(name: "ordered")) { |
42 | // If the root has an "ordered" attribute, we fill the selectedOps |
43 | // vector in a certain order. |
44 | int64_t pos = |
45 | cast<IntegerAttr>(selected->getDiscardableAttr(name: "selected")) |
46 | .getInt(); |
47 | if (pos >= static_cast<int64_t>(selectedOps.size())) |
48 | selectedOps.append(NumInputs: pos + 1 - selectedOps.size(), Elt: nullptr); |
49 | selectedOps[pos] = selected; |
50 | } else { |
51 | selectedOps.push_back(Elt: selected); |
52 | } |
53 | return WalkResult::advance(); |
54 | }); |
55 | |
56 | if (llvm::find(Range&: selectedOps, Val: nullptr) != selectedOps.end()) { |
57 | root->emitError(message: "invalid test case: some indices are missing among the " |
58 | "selected ops"); |
59 | return WalkResult::skip(); |
60 | } |
61 | |
62 | if (!computeTopologicalSorting(ops: selectedOps)) { |
63 | root->emitError(message: "could not schedule all ops"); |
64 | return WalkResult::skip(); |
65 | } |
66 | |
67 | for (const auto &it : llvm::enumerate(First&: selectedOps)) |
68 | it.value()->setAttr("pos", builder.getIndexAttr(it.index())); |
69 | |
70 | return WalkResult::advance(); |
71 | }); |
72 | |
73 | if (result.wasSkipped()) |
74 | signalPassFailure(); |
75 | } |
76 | }; |
77 | } // namespace |
78 | |
79 | namespace mlir { |
80 | namespace test { |
81 | void registerTestTopologicalSortAnalysisPass() { |
82 | PassRegistration<TestTopologicalSortAnalysisPass>(); |
83 | } |
84 | } // namespace test |
85 | } // namespace mlir |
86 |