1 | //===- Predicate.cpp - Predicate class ------------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Wrapper around predicates defined in TableGen. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/TableGen/Predicate.h" |
14 | #include "llvm/ADT/SmallPtrSet.h" |
15 | #include "llvm/ADT/StringExtras.h" |
16 | #include "llvm/ADT/StringSwitch.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "llvm/TableGen/Error.h" |
19 | #include "llvm/TableGen/Record.h" |
20 | |
21 | using namespace mlir; |
22 | using namespace tblgen; |
23 | using llvm::Init; |
24 | using llvm::Record; |
25 | using llvm::SpecificBumpPtrAllocator; |
26 | |
27 | // Construct a Predicate from a record. |
28 | Pred::Pred(const Record *record) : def(record) { |
29 | assert(def->isSubClassOf("Pred") && |
30 | "must be a subclass of TableGen 'Pred' class"); |
31 | } |
32 | |
33 | // Construct a Predicate from an initializer. |
34 | Pred::Pred(const Init *init) { |
35 | if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: init)) |
36 | def = defInit->getDef(); |
37 | } |
38 | |
39 | std::string Pred::getCondition() const { |
40 | // Static dispatch to subclasses. |
41 | if (def->isSubClassOf(Name: "CombinedPred")) |
42 | return static_cast<const CombinedPred *>(this)->getConditionImpl(); |
43 | if (def->isSubClassOf(Name: "CPred")) |
44 | return static_cast<const CPred *>(this)->getConditionImpl(); |
45 | llvm_unreachable("Pred::getCondition must be overridden in subclasses"); |
46 | } |
47 | |
48 | bool Pred::isCombined() const { |
49 | return def && def->isSubClassOf(Name: "CombinedPred"); |
50 | } |
51 | |
52 | ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); } |
53 | |
54 | CPred::CPred(const Record *record) : Pred(record) { |
55 | assert(def->isSubClassOf("CPred") && |
56 | "must be a subclass of Tablegen 'CPred' class"); |
57 | } |
58 | |
59 | CPred::CPred(const Init *init) : Pred(init) { |
60 | assert((!def || def->isSubClassOf("CPred")) && |
61 | "must be a subclass of Tablegen 'CPred' class"); |
62 | } |
63 | |
64 | // Get condition of the C Predicate. |
65 | std::string CPred::getConditionImpl() const { |
66 | assert(!isNull() && "null predicate does not have a condition"); |
67 | return std::string(def->getValueAsString(FieldName: "predExpr")); |
68 | } |
69 | |
70 | CombinedPred::CombinedPred(const Record *record) : Pred(record) { |
71 | assert(def->isSubClassOf("CombinedPred") && |
72 | "must be a subclass of Tablegen 'CombinedPred' class"); |
73 | } |
74 | |
75 | CombinedPred::CombinedPred(const Init *init) : Pred(init) { |
76 | assert((!def || def->isSubClassOf("CombinedPred")) && |
77 | "must be a subclass of Tablegen 'CombinedPred' class"); |
78 | } |
79 | |
80 | const Record *CombinedPred::getCombinerDef() const { |
81 | assert(def->getValue("kind") && "CombinedPred must have a value 'kind'"); |
82 | return def->getValueAsDef(FieldName: "kind"); |
83 | } |
84 | |
85 | std::vector<const Record *> CombinedPred::getChildren() const { |
86 | assert(def->getValue("children") && |
87 | "CombinedPred must have a value 'children'"); |
88 | return def->getValueAsListOfDefs(FieldName: "children"); |
89 | } |
90 | |
91 | namespace { |
92 | // Kinds of nodes in a logical predicate tree. |
93 | enum class PredCombinerKind { |
94 | Leaf, |
95 | And, |
96 | Or, |
97 | Not, |
98 | SubstLeaves, |
99 | Concat, |
100 | // Special kinds that are used in simplification. |
101 | False, |
102 | True |
103 | }; |
104 | |
105 | // A node in a logical predicate tree. |
106 | struct PredNode { |
107 | PredCombinerKind kind; |
108 | const Pred *predicate; |
109 | SmallVector<PredNode *, 4> children; |
110 | std::string expr; |
111 | |
112 | // Prefix and suffix are used by ConcatPred. |
113 | std::string prefix; |
114 | std::string suffix; |
115 | }; |
116 | } // namespace |
117 | |
118 | // Get a predicate tree node kind based on the kind used in the predicate |
119 | // TableGen record. |
120 | static PredCombinerKind getPredCombinerKind(const Pred &pred) { |
121 | if (!pred.isCombined()) |
122 | return PredCombinerKind::Leaf; |
123 | |
124 | const auto &combinedPred = static_cast<const CombinedPred &>(pred); |
125 | return StringSwitch<PredCombinerKind>( |
126 | combinedPred.getCombinerDef()->getName()) |
127 | .Case(S: "PredCombinerAnd", Value: PredCombinerKind::And) |
128 | .Case(S: "PredCombinerOr", Value: PredCombinerKind::Or) |
129 | .Case(S: "PredCombinerNot", Value: PredCombinerKind::Not) |
130 | .Case(S: "PredCombinerSubstLeaves", Value: PredCombinerKind::SubstLeaves) |
131 | .Case(S: "PredCombinerConcat", Value: PredCombinerKind::Concat); |
132 | } |
133 | |
134 | namespace { |
135 | // Substitution<pattern, replacement>. |
136 | using Subst = std::pair<StringRef, StringRef>; |
137 | } // namespace |
138 | |
139 | /// Perform the given substitutions on 'str' in-place. |
140 | static void performSubstitutions(std::string &str, |
141 | ArrayRef<Subst> substitutions) { |
142 | // Apply all parent substitutions from innermost to outermost. |
143 | for (const auto &subst : llvm::reverse(C&: substitutions)) { |
144 | auto pos = str.find(svt: subst.first); |
145 | while (pos != std::string::npos) { |
146 | str.replace(pos: pos, n: subst.first.size(), str: std::string(subst.second)); |
147 | // Skip the newly inserted substring, which itself may consider the |
148 | // pattern to match. |
149 | pos += subst.second.size(); |
150 | // Find the next possible match position. |
151 | pos = str.find(svt: subst.first, pos: pos); |
152 | } |
153 | } |
154 | } |
155 | |
156 | // Build the predicate tree starting from the top-level predicate, which may |
157 | // have children, and perform leaf substitutions inplace. Note that after |
158 | // substitution, nodes are still pointing to the original TableGen record. |
159 | // All nodes are created within "allocator". |
160 | static PredNode * |
161 | buildPredicateTree(const Pred &root, |
162 | SpecificBumpPtrAllocator<PredNode> &allocator, |
163 | ArrayRef<Subst> substitutions) { |
164 | auto *rootNode = allocator.Allocate(); |
165 | new (rootNode) PredNode; |
166 | rootNode->kind = getPredCombinerKind(pred: root); |
167 | rootNode->predicate = &root; |
168 | if (!root.isCombined()) { |
169 | rootNode->expr = root.getCondition(); |
170 | performSubstitutions(str&: rootNode->expr, substitutions); |
171 | return rootNode; |
172 | } |
173 | |
174 | // If the current combined predicate is a leaf substitution, append it to the |
175 | // list before continuing. |
176 | auto allSubstitutions = llvm::to_vector<4>(Range&: substitutions); |
177 | if (rootNode->kind == PredCombinerKind::SubstLeaves) { |
178 | const auto &substPred = static_cast<const SubstLeavesPred &>(root); |
179 | allSubstitutions.push_back( |
180 | Elt: {substPred.getPattern(), substPred.getReplacement()}); |
181 | |
182 | // If the current predicate is a ConcatPred, record the prefix and suffix. |
183 | } else if (rootNode->kind == PredCombinerKind::Concat) { |
184 | const auto &concatPred = static_cast<const ConcatPred &>(root); |
185 | rootNode->prefix = std::string(concatPred.getPrefix()); |
186 | performSubstitutions(str&: rootNode->prefix, substitutions); |
187 | rootNode->suffix = std::string(concatPred.getSuffix()); |
188 | performSubstitutions(str&: rootNode->suffix, substitutions); |
189 | } |
190 | |
191 | // Build child subtrees. |
192 | auto combined = static_cast<const CombinedPred &>(root); |
193 | for (const auto *record : combined.getChildren()) { |
194 | auto *childTree = |
195 | buildPredicateTree(root: Pred(record), allocator, substitutions: allSubstitutions); |
196 | rootNode->children.push_back(Elt: childTree); |
197 | } |
198 | return rootNode; |
199 | } |
200 | |
201 | // Simplify a predicate tree rooted at "node" using the predicates that are |
202 | // known to be true(false). For AND(OR) combined predicates, if any of the |
203 | // children is known to be false(true), the result is also false(true). |
204 | // Furthermore, for AND(OR) combined predicates, children that are known to be |
205 | // true(false) don't have to be checked dynamically. |
206 | static PredNode * |
207 | propagateGroundTruth(PredNode *node, |
208 | const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds, |
209 | const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) { |
210 | // If the current predicate is known to be true or false, change the kind of |
211 | // the node and return immediately. |
212 | if (knownTruePreds.count(Ptr: node->predicate) != 0) { |
213 | node->kind = PredCombinerKind::True; |
214 | node->children.clear(); |
215 | return node; |
216 | } |
217 | if (knownFalsePreds.count(Ptr: node->predicate) != 0) { |
218 | node->kind = PredCombinerKind::False; |
219 | node->children.clear(); |
220 | return node; |
221 | } |
222 | |
223 | // If the current node is a substitution, stop recursion now. |
224 | // The expressions in the leaves below this node were rewritten, but the nodes |
225 | // still point to the original predicate records. While the original |
226 | // predicate may be known to be true or false, it is not necessarily the case |
227 | // after rewriting. |
228 | // TODO: we can support ground truth for rewritten |
229 | // predicates by either (a) having our own unique'ing of the predicates |
230 | // instead of relying on TableGen record pointers or (b) taking ground truth |
231 | // values optionally prefixed with a list of substitutions to apply, e.g. |
232 | // "predX is true by itself as well as predSubY leaf substitution had been |
233 | // applied to it". |
234 | if (node->kind == PredCombinerKind::SubstLeaves) { |
235 | return node; |
236 | } |
237 | |
238 | if (node->kind == PredCombinerKind::And && node->children.empty()) { |
239 | node->kind = PredCombinerKind::True; |
240 | return node; |
241 | } |
242 | |
243 | if (node->kind == PredCombinerKind::Or && node->children.empty()) { |
244 | node->kind = PredCombinerKind::False; |
245 | return node; |
246 | } |
247 | |
248 | // Otherwise, look at child nodes. |
249 | |
250 | // Move child nodes into some local variable so that they can be optimized |
251 | // separately and re-added if necessary. |
252 | llvm::SmallVector<PredNode *, 4> children; |
253 | std::swap(LHS&: node->children, RHS&: children); |
254 | |
255 | for (auto &child : children) { |
256 | // First, simplify the child. This maintains the predicate as it was. |
257 | auto *simplifiedChild = |
258 | propagateGroundTruth(node: child, knownTruePreds, knownFalsePreds); |
259 | |
260 | // Just add the child if we don't know how to simplify the current node. |
261 | if (node->kind != PredCombinerKind::And && |
262 | node->kind != PredCombinerKind::Or) { |
263 | node->children.push_back(Elt: simplifiedChild); |
264 | continue; |
265 | } |
266 | |
267 | // Second, based on the type define which known values of child predicates |
268 | // immediately collapse this predicate to a known value, and which others |
269 | // may be safely ignored. |
270 | // OR(..., True, ...) = True |
271 | // OR(..., False, ...) = OR(..., ...) |
272 | // AND(..., False, ...) = False |
273 | // AND(..., True, ...) = AND(..., ...) |
274 | auto collapseKind = node->kind == PredCombinerKind::And |
275 | ? PredCombinerKind::False |
276 | : PredCombinerKind::True; |
277 | auto eraseKind = node->kind == PredCombinerKind::And |
278 | ? PredCombinerKind::True |
279 | : PredCombinerKind::False; |
280 | const auto &collapseList = |
281 | node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds; |
282 | const auto &eraseList = |
283 | node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds; |
284 | if (simplifiedChild->kind == collapseKind || |
285 | collapseList.count(Ptr: simplifiedChild->predicate) != 0) { |
286 | node->kind = collapseKind; |
287 | node->children.clear(); |
288 | return node; |
289 | } |
290 | if (simplifiedChild->kind == eraseKind || |
291 | eraseList.count(Ptr: simplifiedChild->predicate) != 0) { |
292 | continue; |
293 | } |
294 | node->children.push_back(Elt: simplifiedChild); |
295 | } |
296 | return node; |
297 | } |
298 | |
299 | // Combine a list of predicate expressions using a binary combiner. If a list |
300 | // is empty, return "init". |
301 | static std::string combineBinary(ArrayRef<std::string> children, |
302 | const std::string &combiner, |
303 | std::string init) { |
304 | if (children.empty()) |
305 | return init; |
306 | |
307 | auto size = children.size(); |
308 | if (size == 1) |
309 | return children.front(); |
310 | |
311 | std::string str; |
312 | llvm::raw_string_ostream os(str); |
313 | os << '(' << children.front() << ')'; |
314 | for (unsigned i = 1; i < size; ++i) { |
315 | os << ' ' << combiner << " ("<< children[i] << ')'; |
316 | } |
317 | return str; |
318 | } |
319 | |
320 | // Prepend negation to the only condition in the predicate expression list. |
321 | static std::string combineNot(ArrayRef<std::string> children) { |
322 | assert(children.size() == 1 && "expected exactly one child predicate of Neg"); |
323 | return (Twine("!(") + children.front() + Twine(')')).str(); |
324 | } |
325 | |
326 | // Recursively traverse the predicate tree in depth-first post-order and build |
327 | // the final expression. |
328 | static std::string getCombinedCondition(const PredNode &root) { |
329 | // Immediately return for non-combiner predicates that don't have children. |
330 | if (root.kind == PredCombinerKind::Leaf) |
331 | return root.expr; |
332 | if (root.kind == PredCombinerKind::True) |
333 | return "true"; |
334 | if (root.kind == PredCombinerKind::False) |
335 | return "false"; |
336 | |
337 | // Recurse into children. |
338 | llvm::SmallVector<std::string, 4> childExpressions; |
339 | childExpressions.reserve(N: root.children.size()); |
340 | for (const auto &child : root.children) |
341 | childExpressions.push_back(Elt: getCombinedCondition(root: *child)); |
342 | |
343 | // Combine the expressions based on the predicate node kind. |
344 | if (root.kind == PredCombinerKind::And) |
345 | return combineBinary(children: childExpressions, combiner: "&&", init: "true"); |
346 | if (root.kind == PredCombinerKind::Or) |
347 | return combineBinary(children: childExpressions, combiner: "||", init: "false"); |
348 | if (root.kind == PredCombinerKind::Not) |
349 | return combineNot(children: childExpressions); |
350 | if (root.kind == PredCombinerKind::Concat) { |
351 | assert(childExpressions.size() == 1 && |
352 | "ConcatPred should only have one child"); |
353 | return root.prefix + childExpressions.front() + root.suffix; |
354 | } |
355 | |
356 | // Substitutions were applied before so just ignore them. |
357 | if (root.kind == PredCombinerKind::SubstLeaves) { |
358 | assert(childExpressions.size() == 1 && |
359 | "substitution predicate must have one child"); |
360 | return childExpressions[0]; |
361 | } |
362 | |
363 | llvm::PrintFatalError(ErrorLoc: root.predicate->getLoc(), Msg: "unsupported predicate kind"); |
364 | } |
365 | |
366 | std::string CombinedPred::getConditionImpl() const { |
367 | SpecificBumpPtrAllocator<PredNode> allocator; |
368 | auto *predicateTree = buildPredicateTree(root: *this, allocator, substitutions: {}); |
369 | predicateTree = |
370 | propagateGroundTruth(node: predicateTree, |
371 | /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(), |
372 | /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>()); |
373 | |
374 | return getCombinedCondition(root: *predicateTree); |
375 | } |
376 | |
377 | StringRef SubstLeavesPred::getPattern() const { |
378 | return def->getValueAsString(FieldName: "pattern"); |
379 | } |
380 | |
381 | StringRef SubstLeavesPred::getReplacement() const { |
382 | return def->getValueAsString(FieldName: "replacement"); |
383 | } |
384 | |
385 | StringRef ConcatPred::getPrefix() const { |
386 | return def->getValueAsString(FieldName: "prefix"); |
387 | } |
388 | |
389 | StringRef ConcatPred::getSuffix() const { |
390 | return def->getValueAsString(FieldName: "suffix"); |
391 | } |
392 |
Definitions
- Pred
- Pred
- getCondition
- isCombined
- getLoc
- CPred
- CPred
- getConditionImpl
- CombinedPred
- CombinedPred
- getCombinerDef
- getChildren
- PredCombinerKind
- PredNode
- getPredCombinerKind
- performSubstitutions
- buildPredicateTree
- propagateGroundTruth
- combineBinary
- combineNot
- getCombinedCondition
- getConditionImpl
- getPattern
- getReplacement
- getPrefix
Update your C++ knowledge – Modern C++11/14/17 Training
Find out more