1 | //===- Predicate.cpp - Predicate class ------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Wrapper around predicates defined in TableGen. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/Predicate.h" |
14 | #include "llvm/ADT/SmallPtrSet.h" |
15 | #include "llvm/ADT/StringExtras.h" |
16 | #include "llvm/ADT/StringSwitch.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "llvm/TableGen/Error.h" |
19 | #include "llvm/TableGen/Record.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace tblgen; |
23 | |
24 | // Construct a Predicate from a record. |
25 | Pred::Pred(const llvm::Record *record) : def(record) { |
26 | assert(def->isSubClassOf("Pred" ) && |
27 | "must be a subclass of TableGen 'Pred' class" ); |
28 | } |
29 | |
30 | // Construct a Predicate from an initializer. |
31 | Pred::Pred(const llvm::Init *init) { |
32 | if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: init)) |
33 | def = defInit->getDef(); |
34 | } |
35 | |
36 | std::string Pred::getCondition() const { |
37 | // Static dispatch to subclasses. |
38 | if (def->isSubClassOf(Name: "CombinedPred" )) |
39 | return static_cast<const CombinedPred *>(this)->getConditionImpl(); |
40 | if (def->isSubClassOf(Name: "CPred" )) |
41 | return static_cast<const CPred *>(this)->getConditionImpl(); |
42 | llvm_unreachable("Pred::getCondition must be overridden in subclasses" ); |
43 | } |
44 | |
45 | bool Pred::isCombined() const { |
46 | return def && def->isSubClassOf(Name: "CombinedPred" ); |
47 | } |
48 | |
49 | ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); } |
50 | |
51 | CPred::CPred(const llvm::Record *record) : Pred(record) { |
52 | assert(def->isSubClassOf("CPred" ) && |
53 | "must be a subclass of Tablegen 'CPred' class" ); |
54 | } |
55 | |
56 | CPred::CPred(const llvm::Init *init) : Pred(init) { |
57 | assert((!def || def->isSubClassOf("CPred" )) && |
58 | "must be a subclass of Tablegen 'CPred' class" ); |
59 | } |
60 | |
61 | // Get condition of the C Predicate. |
62 | std::string CPred::getConditionImpl() const { |
63 | assert(!isNull() && "null predicate does not have a condition" ); |
64 | return std::string(def->getValueAsString(FieldName: "predExpr" )); |
65 | } |
66 | |
67 | CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) { |
68 | assert(def->isSubClassOf("CombinedPred" ) && |
69 | "must be a subclass of Tablegen 'CombinedPred' class" ); |
70 | } |
71 | |
72 | CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) { |
73 | assert((!def || def->isSubClassOf("CombinedPred" )) && |
74 | "must be a subclass of Tablegen 'CombinedPred' class" ); |
75 | } |
76 | |
77 | const llvm::Record *CombinedPred::getCombinerDef() const { |
78 | assert(def->getValue("kind" ) && "CombinedPred must have a value 'kind'" ); |
79 | return def->getValueAsDef(FieldName: "kind" ); |
80 | } |
81 | |
82 | std::vector<llvm::Record *> CombinedPred::getChildren() const { |
83 | assert(def->getValue("children" ) && |
84 | "CombinedPred must have a value 'children'" ); |
85 | return def->getValueAsListOfDefs(FieldName: "children" ); |
86 | } |
87 | |
88 | namespace { |
89 | // Kinds of nodes in a logical predicate tree. |
90 | enum class PredCombinerKind { |
91 | Leaf, |
92 | And, |
93 | Or, |
94 | Not, |
95 | SubstLeaves, |
96 | Concat, |
97 | // Special kinds that are used in simplification. |
98 | False, |
99 | True |
100 | }; |
101 | |
102 | // A node in a logical predicate tree. |
103 | struct PredNode { |
104 | PredCombinerKind kind; |
105 | const Pred *predicate; |
106 | SmallVector<PredNode *, 4> children; |
107 | std::string expr; |
108 | |
109 | // Prefix and suffix are used by ConcatPred. |
110 | std::string prefix; |
111 | std::string suffix; |
112 | }; |
113 | } // namespace |
114 | |
115 | // Get a predicate tree node kind based on the kind used in the predicate |
116 | // TableGen record. |
117 | static PredCombinerKind getPredCombinerKind(const Pred &pred) { |
118 | if (!pred.isCombined()) |
119 | return PredCombinerKind::Leaf; |
120 | |
121 | const auto &combinedPred = static_cast<const CombinedPred &>(pred); |
122 | return StringSwitch<PredCombinerKind>( |
123 | combinedPred.getCombinerDef()->getName()) |
124 | .Case(S: "PredCombinerAnd" , Value: PredCombinerKind::And) |
125 | .Case(S: "PredCombinerOr" , Value: PredCombinerKind::Or) |
126 | .Case(S: "PredCombinerNot" , Value: PredCombinerKind::Not) |
127 | .Case(S: "PredCombinerSubstLeaves" , Value: PredCombinerKind::SubstLeaves) |
128 | .Case(S: "PredCombinerConcat" , Value: PredCombinerKind::Concat); |
129 | } |
130 | |
131 | namespace { |
132 | // Substitution<pattern, replacement>. |
133 | using Subst = std::pair<StringRef, StringRef>; |
134 | } // namespace |
135 | |
136 | /// Perform the given substitutions on 'str' in-place. |
137 | static void performSubstitutions(std::string &str, |
138 | ArrayRef<Subst> substitutions) { |
139 | // Apply all parent substitutions from innermost to outermost. |
140 | for (const auto &subst : llvm::reverse(C&: substitutions)) { |
141 | auto pos = str.find(str: std::string(subst.first)); |
142 | while (pos != std::string::npos) { |
143 | str.replace(pos: pos, n: subst.first.size(), str: std::string(subst.second)); |
144 | // Skip the newly inserted substring, which itself may consider the |
145 | // pattern to match. |
146 | pos += subst.second.size(); |
147 | // Find the next possible match position. |
148 | pos = str.find(str: std::string(subst.first), pos: pos); |
149 | } |
150 | } |
151 | } |
152 | |
153 | // Build the predicate tree starting from the top-level predicate, which may |
154 | // have children, and perform leaf substitutions inplace. Note that after |
155 | // substitution, nodes are still pointing to the original TableGen record. |
156 | // All nodes are created within "allocator". |
157 | static PredNode * |
158 | buildPredicateTree(const Pred &root, |
159 | llvm::SpecificBumpPtrAllocator<PredNode> &allocator, |
160 | ArrayRef<Subst> substitutions) { |
161 | auto *rootNode = allocator.Allocate(); |
162 | new (rootNode) PredNode; |
163 | rootNode->kind = getPredCombinerKind(pred: root); |
164 | rootNode->predicate = &root; |
165 | if (!root.isCombined()) { |
166 | rootNode->expr = root.getCondition(); |
167 | performSubstitutions(str&: rootNode->expr, substitutions); |
168 | return rootNode; |
169 | } |
170 | |
171 | // If the current combined predicate is a leaf substitution, append it to the |
172 | // list before continuing. |
173 | auto allSubstitutions = llvm::to_vector<4>(Range&: substitutions); |
174 | if (rootNode->kind == PredCombinerKind::SubstLeaves) { |
175 | const auto &substPred = static_cast<const SubstLeavesPred &>(root); |
176 | allSubstitutions.push_back( |
177 | Elt: {substPred.getPattern(), substPred.getReplacement()}); |
178 | |
179 | // If the current predicate is a ConcatPred, record the prefix and suffix. |
180 | } else if (rootNode->kind == PredCombinerKind::Concat) { |
181 | const auto &concatPred = static_cast<const ConcatPred &>(root); |
182 | rootNode->prefix = std::string(concatPred.getPrefix()); |
183 | performSubstitutions(str&: rootNode->prefix, substitutions); |
184 | rootNode->suffix = std::string(concatPred.getSuffix()); |
185 | performSubstitutions(str&: rootNode->suffix, substitutions); |
186 | } |
187 | |
188 | // Build child subtrees. |
189 | auto combined = static_cast<const CombinedPred &>(root); |
190 | for (const auto *record : combined.getChildren()) { |
191 | auto *childTree = |
192 | buildPredicateTree(root: Pred(record), allocator, substitutions: allSubstitutions); |
193 | rootNode->children.push_back(Elt: childTree); |
194 | } |
195 | return rootNode; |
196 | } |
197 | |
198 | // Simplify a predicate tree rooted at "node" using the predicates that are |
199 | // known to be true(false). For AND(OR) combined predicates, if any of the |
200 | // children is known to be false(true), the result is also false(true). |
201 | // Furthermore, for AND(OR) combined predicates, children that are known to be |
202 | // true(false) don't have to be checked dynamically. |
203 | static PredNode * |
204 | propagateGroundTruth(PredNode *node, |
205 | const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, |
206 | const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { |
207 | // If the current predicate is known to be true or false, change the kind of |
208 | // the node and return immediately. |
209 | if (knownTruePreds.count(Ptr: node->predicate) != 0) { |
210 | node->kind = PredCombinerKind::True; |
211 | node->children.clear(); |
212 | return node; |
213 | } |
214 | if (knownFalsePreds.count(Ptr: node->predicate) != 0) { |
215 | node->kind = PredCombinerKind::False; |
216 | node->children.clear(); |
217 | return node; |
218 | } |
219 | |
220 | // If the current node is a substitution, stop recursion now. |
221 | // The expressions in the leaves below this node were rewritten, but the nodes |
222 | // still point to the original predicate records. While the original |
223 | // predicate may be known to be true or false, it is not necessarily the case |
224 | // after rewriting. |
225 | // TODO: we can support ground truth for rewritten |
226 | // predicates by either (a) having our own unique'ing of the predicates |
227 | // instead of relying on TableGen record pointers or (b) taking ground truth |
228 | // values optionally prefixed with a list of substitutions to apply, e.g. |
229 | // "predX is true by itself as well as predSubY leaf substitution had been |
230 | // applied to it". |
231 | if (node->kind == PredCombinerKind::SubstLeaves) { |
232 | return node; |
233 | } |
234 | |
235 | // Otherwise, look at child nodes. |
236 | |
237 | // Move child nodes into some local variable so that they can be optimized |
238 | // separately and re-added if necessary. |
239 | llvm::SmallVector<PredNode *, 4> children; |
240 | std::swap(LHS&: node->children, RHS&: children); |
241 | |
242 | for (auto &child : children) { |
243 | // First, simplify the child. This maintains the predicate as it was. |
244 | auto *simplifiedChild = |
245 | propagateGroundTruth(node: child, knownTruePreds, knownFalsePreds); |
246 | |
247 | // Just add the child if we don't know how to simplify the current node. |
248 | if (node->kind != PredCombinerKind::And && |
249 | node->kind != PredCombinerKind::Or) { |
250 | node->children.push_back(Elt: simplifiedChild); |
251 | continue; |
252 | } |
253 | |
254 | // Second, based on the type define which known values of child predicates |
255 | // immediately collapse this predicate to a known value, and which others |
256 | // may be safely ignored. |
257 | // OR(..., True, ...) = True |
258 | // OR(..., False, ...) = OR(..., ...) |
259 | // AND(..., False, ...) = False |
260 | // AND(..., True, ...) = AND(..., ...) |
261 | auto collapseKind = node->kind == PredCombinerKind::And |
262 | ? PredCombinerKind::False |
263 | : PredCombinerKind::True; |
264 | auto eraseKind = node->kind == PredCombinerKind::And |
265 | ? PredCombinerKind::True |
266 | : PredCombinerKind::False; |
267 | const auto &collapseList = |
268 | node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; |
269 | const auto &eraseList = |
270 | node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; |
271 | if (simplifiedChild->kind == collapseKind || |
272 | collapseList.count(Ptr: simplifiedChild->predicate) != 0) { |
273 | node->kind = collapseKind; |
274 | node->children.clear(); |
275 | return node; |
276 | } |
277 | if (simplifiedChild->kind == eraseKind || |
278 | eraseList.count(Ptr: simplifiedChild->predicate) != 0) { |
279 | continue; |
280 | } |
281 | node->children.push_back(Elt: simplifiedChild); |
282 | } |
283 | return node; |
284 | } |
285 | |
286 | // Combine a list of predicate expressions using a binary combiner. If a list |
287 | // is empty, return "init". |
288 | static std::string combineBinary(ArrayRef<std::string> children, |
289 | const std::string &combiner, |
290 | std::string init) { |
291 | if (children.empty()) |
292 | return init; |
293 | |
294 | auto size = children.size(); |
295 | if (size == 1) |
296 | return children.front(); |
297 | |
298 | std::string str; |
299 | llvm::raw_string_ostream os(str); |
300 | os << '(' << children.front() << ')'; |
301 | for (unsigned i = 1; i < size; ++i) { |
302 | os << ' ' << combiner << " (" << children[i] << ')'; |
303 | } |
304 | return os.str(); |
305 | } |
306 | |
307 | // Prepend negation to the only condition in the predicate expression list. |
308 | static std::string combineNot(ArrayRef<std::string> children) { |
309 | assert(children.size() == 1 && "expected exactly one child predicate of Neg" ); |
310 | return (Twine("!(" ) + children.front() + Twine(')')).str(); |
311 | } |
312 | |
313 | // Recursively traverse the predicate tree in depth-first post-order and build |
314 | // the final expression. |
315 | static std::string getCombinedCondition(const PredNode &root) { |
316 | // Immediately return for non-combiner predicates that don't have children. |
317 | if (root.kind == PredCombinerKind::Leaf) |
318 | return root.expr; |
319 | if (root.kind == PredCombinerKind::True) |
320 | return "true" ; |
321 | if (root.kind == PredCombinerKind::False) |
322 | return "false" ; |
323 | |
324 | // Recurse into children. |
325 | llvm::SmallVector<std::string, 4> childExpressions; |
326 | childExpressions.reserve(N: root.children.size()); |
327 | for (const auto &child : root.children) |
328 | childExpressions.push_back(Elt: getCombinedCondition(root: *child)); |
329 | |
330 | // Combine the expressions based on the predicate node kind. |
331 | if (root.kind == PredCombinerKind::And) |
332 | return combineBinary(children: childExpressions, combiner: "&&" , init: "true" ); |
333 | if (root.kind == PredCombinerKind::Or) |
334 | return combineBinary(children: childExpressions, combiner: "||" , init: "false" ); |
335 | if (root.kind == PredCombinerKind::Not) |
336 | return combineNot(children: childExpressions); |
337 | if (root.kind == PredCombinerKind::Concat) { |
338 | assert(childExpressions.size() == 1 && |
339 | "ConcatPred should only have one child" ); |
340 | return root.prefix + childExpressions.front() + root.suffix; |
341 | } |
342 | |
343 | // Substitutions were applied before so just ignore them. |
344 | if (root.kind == PredCombinerKind::SubstLeaves) { |
345 | assert(childExpressions.size() == 1 && |
346 | "substitution predicate must have one child" ); |
347 | return childExpressions[0]; |
348 | } |
349 | |
350 | llvm::PrintFatalError(ErrorLoc: root.predicate->getLoc(), Msg: "unsupported predicate kind" ); |
351 | } |
352 | |
353 | std::string CombinedPred::getConditionImpl() const { |
354 | llvm::SpecificBumpPtrAllocator<PredNode> allocator; |
355 | auto *predicateTree = buildPredicateTree(root: *this, allocator, substitutions: {}); |
356 | predicateTree = |
357 | propagateGroundTruth(node: predicateTree, |
358 | /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), |
359 | /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); |
360 | |
361 | return getCombinedCondition(root: *predicateTree); |
362 | } |
363 | |
364 | StringRef SubstLeavesPred::getPattern() const { |
365 | return def->getValueAsString(FieldName: "pattern" ); |
366 | } |
367 | |
368 | StringRef SubstLeavesPred::getReplacement() const { |
369 | return def->getValueAsString(FieldName: "replacement" ); |
370 | } |
371 | |
372 | StringRef ConcatPred::getPrefix() const { |
373 | return def->getValueAsString(FieldName: "prefix" ); |
374 | } |
375 | |
376 | StringRef ConcatPred::getSuffix() const { |
377 | return def->getValueAsString(FieldName: "suffix" ); |
378 | } |
379 | |