1 | //===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include <utility> |
10 | |
11 | #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" |
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
13 | |
14 | #include "llvm/ADT/ArrayRef.h" |
15 | #include "llvm/ADT/STLExtras.h" |
16 | #include "llvm/Support/Allocator.h" |
17 | #include "llvm/Support/raw_ostream.h" |
18 | |
19 | using namespace mlir; |
20 | using namespace mlir::affine; |
21 | |
22 | llvm::BumpPtrAllocator *&NestedMatch::allocator() { |
23 | thread_local llvm::BumpPtrAllocator *allocator = nullptr; |
24 | return allocator; |
25 | } |
26 | |
27 | NestedMatch NestedMatch::build(Operation *operation, |
28 | ArrayRef<NestedMatch> nestedMatches) { |
29 | auto *result = allocator()->Allocate<NestedMatch>(); |
30 | auto *children = allocator()->Allocate<NestedMatch>(Num: nestedMatches.size()); |
31 | std::uninitialized_copy(first: nestedMatches.begin(), last: nestedMatches.end(), result: children); |
32 | new (result) NestedMatch(); |
33 | result->matchedOperation = operation; |
34 | result->matchedChildren = |
35 | ArrayRef<NestedMatch>(children, nestedMatches.size()); |
36 | return *result; |
37 | } |
38 | |
39 | llvm::BumpPtrAllocator *&NestedPattern::allocator() { |
40 | thread_local llvm::BumpPtrAllocator *allocator = nullptr; |
41 | return allocator; |
42 | } |
43 | |
44 | void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) { |
45 | if (nested.empty()) |
46 | return; |
47 | |
48 | auto *newNested = allocator()->Allocate<NestedPattern>(Num: nested.size()); |
49 | std::uninitialized_copy(first: nested.begin(), last: nested.end(), result: newNested); |
50 | nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size()); |
51 | } |
52 | |
53 | void NestedPattern::freeNested() { |
54 | for (const auto &p : nestedPatterns) |
55 | p.~NestedPattern(); |
56 | } |
57 | |
58 | NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested, |
59 | FilterFunctionType filter) |
60 | : filter(std::move(filter)), skip(nullptr) { |
61 | copyNestedToThis(nested); |
62 | } |
63 | |
64 | NestedPattern::NestedPattern(const NestedPattern &other) |
65 | : filter(other.filter), skip(other.skip) { |
66 | copyNestedToThis(nested: other.nestedPatterns); |
67 | } |
68 | |
69 | NestedPattern &NestedPattern::operator=(const NestedPattern &other) { |
70 | freeNested(); |
71 | filter = other.filter; |
72 | skip = other.skip; |
73 | copyNestedToThis(nested: other.nestedPatterns); |
74 | return *this; |
75 | } |
76 | |
77 | unsigned NestedPattern::getDepth() const { |
78 | if (nestedPatterns.empty()) { |
79 | return 1; |
80 | } |
81 | unsigned depth = 0; |
82 | for (auto &c : nestedPatterns) { |
83 | depth = std::max(a: depth, b: c.getDepth()); |
84 | } |
85 | return depth + 1; |
86 | } |
87 | |
88 | /// Matches a single operation in the following way: |
89 | /// 1. checks the kind of operation against the matcher, if different then |
90 | /// there is no match; |
91 | /// 2. calls the customizable filter function to refine the single operation |
92 | /// match with extra semantic constraints; |
93 | /// 3. if all is good, recursively matches the nested patterns; |
94 | /// 4. if all nested match then the single operation matches too and is |
95 | /// appended to the list of matches; |
96 | /// 5. TODO: Optionally applies actions (lambda), in which case we will want |
97 | /// to traverse in post-order DFS to avoid invalidating iterators. |
98 | void NestedPattern::matchOne(Operation *op, |
99 | SmallVectorImpl<NestedMatch> *matches) { |
100 | if (skip == op) { |
101 | return; |
102 | } |
103 | // Local custom filter function |
104 | if (!filter(*op)) { |
105 | return; |
106 | } |
107 | |
108 | if (nestedPatterns.empty()) { |
109 | SmallVector<NestedMatch, 8> nestedMatches; |
110 | matches->push_back(Elt: NestedMatch::build(operation: op, nestedMatches)); |
111 | return; |
112 | } |
113 | // Take a copy of each nested pattern so we can match it. |
114 | for (auto nestedPattern : nestedPatterns) { |
115 | SmallVector<NestedMatch, 8> nestedMatches; |
116 | // Skip elem in the walk immediately following. Without this we would |
117 | // essentially need to reimplement walk here. |
118 | nestedPattern.skip = op; |
119 | nestedPattern.match(op, matches: &nestedMatches); |
120 | // If we could not match even one of the specified nestedPattern, early exit |
121 | // as this whole branch is not a match. |
122 | if (nestedMatches.empty()) { |
123 | return; |
124 | } |
125 | matches->push_back(Elt: NestedMatch::build(operation: op, nestedMatches)); |
126 | } |
127 | } |
128 | |
129 | static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); } |
130 | |
131 | static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); } |
132 | |
133 | namespace mlir { |
134 | namespace affine { |
135 | namespace matcher { |
136 | |
137 | NestedPattern Op(FilterFunctionType filter) { |
138 | return NestedPattern({}, std::move(filter)); |
139 | } |
140 | |
141 | NestedPattern If(const NestedPattern &child) { |
142 | return NestedPattern(child, isAffineIfOp); |
143 | } |
144 | NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) { |
145 | return NestedPattern(child, [filter](Operation &op) { |
146 | return isAffineIfOp(op) && filter(op); |
147 | }); |
148 | } |
149 | NestedPattern If(ArrayRef<NestedPattern> nested) { |
150 | return NestedPattern(nested, isAffineIfOp); |
151 | } |
152 | NestedPattern If(const FilterFunctionType &filter, |
153 | ArrayRef<NestedPattern> nested) { |
154 | return NestedPattern(nested, [filter](Operation &op) { |
155 | return isAffineIfOp(op) && filter(op); |
156 | }); |
157 | } |
158 | |
159 | NestedPattern For(const NestedPattern &child) { |
160 | return NestedPattern(child, isAffineForOp); |
161 | } |
162 | NestedPattern For(const FilterFunctionType &filter, |
163 | const NestedPattern &child) { |
164 | return NestedPattern( |
165 | child, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); |
166 | } |
167 | NestedPattern For(ArrayRef<NestedPattern> nested) { |
168 | return NestedPattern(nested, isAffineForOp); |
169 | } |
170 | NestedPattern For(const FilterFunctionType &filter, |
171 | ArrayRef<NestedPattern> nested) { |
172 | return NestedPattern( |
173 | nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); }); |
174 | } |
175 | |
176 | bool isLoadOrStore(Operation &op) { |
177 | return isa<AffineLoadOp, AffineStoreOp>(op); |
178 | } |
179 | |
180 | } // namespace matcher |
181 | } // namespace affine |
182 | } // namespace mlir |
183 | |