1//===- NestedMatcher.cpp - NestedMatcher Impl ----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
11#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13
14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/Support/Allocator.h"
17#include "llvm/Support/raw_ostream.h"
18
19using namespace mlir;
20using namespace mlir::affine;
21
22llvm::BumpPtrAllocator *&NestedMatch::allocator() {
23 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
24 return allocator;
25}
26
27NestedMatch NestedMatch::build(Operation *operation,
28 ArrayRef<NestedMatch> nestedMatches) {
29 auto *result = allocator()->Allocate<NestedMatch>();
30 auto *children = allocator()->Allocate<NestedMatch>(Num: nestedMatches.size());
31 std::uninitialized_copy(first: nestedMatches.begin(), last: nestedMatches.end(), result: children);
32 new (result) NestedMatch();
33 result->matchedOperation = operation;
34 result->matchedChildren =
35 ArrayRef<NestedMatch>(children, nestedMatches.size());
36 return *result;
37}
38
39llvm::BumpPtrAllocator *&NestedPattern::allocator() {
40 thread_local llvm::BumpPtrAllocator *allocator = nullptr;
41 return allocator;
42}
43
44void NestedPattern::copyNestedToThis(ArrayRef<NestedPattern> nested) {
45 if (nested.empty())
46 return;
47
48 auto *newNested = allocator()->Allocate<NestedPattern>(Num: nested.size());
49 std::uninitialized_copy(first: nested.begin(), last: nested.end(), result: newNested);
50 nestedPatterns = ArrayRef<NestedPattern>(newNested, nested.size());
51}
52
53void NestedPattern::freeNested() {
54 for (const auto &p : nestedPatterns)
55 p.~NestedPattern();
56}
57
58NestedPattern::NestedPattern(ArrayRef<NestedPattern> nested,
59 FilterFunctionType filter)
60 : filter(std::move(filter)), skip(nullptr) {
61 copyNestedToThis(nested);
62}
63
64NestedPattern::NestedPattern(const NestedPattern &other)
65 : filter(other.filter), skip(other.skip) {
66 copyNestedToThis(nested: other.nestedPatterns);
67}
68
69NestedPattern &NestedPattern::operator=(const NestedPattern &other) {
70 freeNested();
71 filter = other.filter;
72 skip = other.skip;
73 copyNestedToThis(nested: other.nestedPatterns);
74 return *this;
75}
76
77unsigned NestedPattern::getDepth() const {
78 if (nestedPatterns.empty()) {
79 return 1;
80 }
81 unsigned depth = 0;
82 for (auto &c : nestedPatterns) {
83 depth = std::max(a: depth, b: c.getDepth());
84 }
85 return depth + 1;
86}
87
88/// Matches a single operation in the following way:
89/// 1. checks the kind of operation against the matcher, if different then
90/// there is no match;
91/// 2. calls the customizable filter function to refine the single operation
92/// match with extra semantic constraints;
93/// 3. if all is good, recursively matches the nested patterns;
94/// 4. if all nested match then the single operation matches too and is
95/// appended to the list of matches;
96/// 5. TODO: Optionally applies actions (lambda), in which case we will want
97/// to traverse in post-order DFS to avoid invalidating iterators.
98void NestedPattern::matchOne(Operation *op,
99 SmallVectorImpl<NestedMatch> *matches) {
100 if (skip == op) {
101 return;
102 }
103 // Local custom filter function
104 if (!filter(*op)) {
105 return;
106 }
107
108 if (nestedPatterns.empty()) {
109 SmallVector<NestedMatch, 8> nestedMatches;
110 matches->push_back(Elt: NestedMatch::build(operation: op, nestedMatches));
111 return;
112 }
113 // Take a copy of each nested pattern so we can match it.
114 for (auto nestedPattern : nestedPatterns) {
115 SmallVector<NestedMatch, 8> nestedMatches;
116 // Skip elem in the walk immediately following. Without this we would
117 // essentially need to reimplement walk here.
118 nestedPattern.skip = op;
119 nestedPattern.match(op, matches: &nestedMatches);
120 // If we could not match even one of the specified nestedPattern, early exit
121 // as this whole branch is not a match.
122 if (nestedMatches.empty()) {
123 return;
124 }
125 matches->push_back(Elt: NestedMatch::build(operation: op, nestedMatches));
126 }
127}
128
129static bool isAffineForOp(Operation &op) { return isa<AffineForOp>(op); }
130
131static bool isAffineIfOp(Operation &op) { return isa<AffineIfOp>(op); }
132
133namespace mlir {
134namespace affine {
135namespace matcher {
136
137NestedPattern Op(FilterFunctionType filter) {
138 return NestedPattern({}, std::move(filter));
139}
140
141NestedPattern If(const NestedPattern &child) {
142 return NestedPattern(child, isAffineIfOp);
143}
144NestedPattern If(const FilterFunctionType &filter, const NestedPattern &child) {
145 return NestedPattern(child, [filter](Operation &op) {
146 return isAffineIfOp(op) && filter(op);
147 });
148}
149NestedPattern If(ArrayRef<NestedPattern> nested) {
150 return NestedPattern(nested, isAffineIfOp);
151}
152NestedPattern If(const FilterFunctionType &filter,
153 ArrayRef<NestedPattern> nested) {
154 return NestedPattern(nested, [filter](Operation &op) {
155 return isAffineIfOp(op) && filter(op);
156 });
157}
158
159NestedPattern For(const NestedPattern &child) {
160 return NestedPattern(child, isAffineForOp);
161}
162NestedPattern For(const FilterFunctionType &filter,
163 const NestedPattern &child) {
164 return NestedPattern(
165 child, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
166}
167NestedPattern For(ArrayRef<NestedPattern> nested) {
168 return NestedPattern(nested, isAffineForOp);
169}
170NestedPattern For(const FilterFunctionType &filter,
171 ArrayRef<NestedPattern> nested) {
172 return NestedPattern(
173 nested, [=](Operation &op) { return isAffineForOp(op) && filter(op); });
174}
175
176bool isLoadOrStore(Operation &op) {
177 return isa<AffineLoadOp, AffineStoreOp>(op);
178}
179
180} // namespace matcher
181} // namespace affine
182} // namespace mlir
183

source code of mlir/lib/Dialect/Affine/Analysis/NestedMatcher.cpp