//===- Predicate.cpp - Predicate class ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Wrapper around predicates defined in TableGen.
//
//===----------------------------------------------------------------------===//

#include "mlir/TableGen/Predicate.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <optional>

using namespace mlir;
using namespace tblgen;
using llvm::Init;
using llvm::Record;
using llvm::SpecificBumpPtrAllocator;

27// Construct a Predicate from a record.
28Pred::Pred(const Record *record) : def(record) {
29 assert(def->isSubClassOf("Pred") &&
30 "must be a subclass of TableGen 'Pred' class");
31}
32
33// Construct a Predicate from an initializer.
34Pred::Pred(const Init *init) {
35 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: init))
36 def = defInit->getDef();
37}
38
39std::string Pred::getCondition() const {
40 // Static dispatch to subclasses.
41 if (def->isSubClassOf(Name: "CombinedPred"))
42 return static_cast<const CombinedPred *>(this)->getConditionImpl();
43 if (def->isSubClassOf(Name: "CPred"))
44 return static_cast<const CPred *>(this)->getConditionImpl();
45 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
46}
47
48bool Pred::isCombined() const {
49 return def && def->isSubClassOf(Name: "CombinedPred");
50}
51
52ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
53
54CPred::CPred(const Record *record) : Pred(record) {
55 assert(def->isSubClassOf("CPred") &&
56 "must be a subclass of Tablegen 'CPred' class");
57}
58
59CPred::CPred(const Init *init) : Pred(init) {
60 assert((!def || def->isSubClassOf("CPred")) &&
61 "must be a subclass of Tablegen 'CPred' class");
62}
63
64// Get condition of the C Predicate.
65std::string CPred::getConditionImpl() const {
66 assert(!isNull() && "null predicate does not have a condition");
67 return std::string(def->getValueAsString(FieldName: "predExpr"));
68}
69
70CombinedPred::CombinedPred(const Record *record) : Pred(record) {
71 assert(def->isSubClassOf("CombinedPred") &&
72 "must be a subclass of Tablegen 'CombinedPred' class");
73}
74
75CombinedPred::CombinedPred(const Init *init) : Pred(init) {
76 assert((!def || def->isSubClassOf("CombinedPred")) &&
77 "must be a subclass of Tablegen 'CombinedPred' class");
78}
79
80const Record *CombinedPred::getCombinerDef() const {
81 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
82 return def->getValueAsDef(FieldName: "kind");
83}
84
85std::vector<const Record *> CombinedPred::getChildren() const {
86 assert(def->getValue("children") &&
87 "CombinedPred must have a value 'children'");
88 return def->getValueAsListOfDefs(FieldName: "children");
89}
90
91namespace {
92// Kinds of nodes in a logical predicate tree.
93enum class PredCombinerKind {
94 Leaf,
95 And,
96 Or,
97 Not,
98 SubstLeaves,
99 Concat,
100 // Special kinds that are used in simplification.
101 False,
102 True
103};
104
105// A node in a logical predicate tree.
106struct PredNode {
107 PredCombinerKind kind;
108 const Pred *predicate;
109 SmallVector<PredNode *, 4> children;
110 std::string expr;
111
112 // Prefix and suffix are used by ConcatPred.
113 std::string prefix;
114 std::string suffix;
115};
116} // namespace
117
118// Get a predicate tree node kind based on the kind used in the predicate
119// TableGen record.
120static PredCombinerKind getPredCombinerKind(const Pred &pred) {
121 if (!pred.isCombined())
122 return PredCombinerKind::Leaf;
123
124 const auto &combinedPred = static_cast<const CombinedPred &>(pred);
125 return StringSwitch<PredCombinerKind>(
126 combinedPred.getCombinerDef()->getName())
127 .Case(S: "PredCombinerAnd", Value: PredCombinerKind::And)
128 .Case(S: "PredCombinerOr", Value: PredCombinerKind::Or)
129 .Case(S: "PredCombinerNot", Value: PredCombinerKind::Not)
130 .Case(S: "PredCombinerSubstLeaves", Value: PredCombinerKind::SubstLeaves)
131 .Case(S: "PredCombinerConcat", Value: PredCombinerKind::Concat);
132}
133
134namespace {
135// Substitution<pattern, replacement>.
136using Subst = std::pair<StringRef, StringRef>;
137} // namespace
138
139/// Perform the given substitutions on 'str' in-place.
140static void performSubstitutions(std::string &str,
141 ArrayRef<Subst> substitutions) {
142 // Apply all parent substitutions from innermost to outermost.
143 for (const auto &subst : llvm::reverse(C&: substitutions)) {
144 auto pos = str.find(svt: subst.first);
145 while (pos != std::string::npos) {
146 str.replace(pos: pos, n: subst.first.size(), str: std::string(subst.second));
147 // Skip the newly inserted substring, which itself may consider the
148 // pattern to match.
149 pos += subst.second.size();
150 // Find the next possible match position.
151 pos = str.find(svt: subst.first, pos: pos);
152 }
153 }
154}
155
156// Build the predicate tree starting from the top-level predicate, which may
157// have children, and perform leaf substitutions inplace. Note that after
158// substitution, nodes are still pointing to the original TableGen record.
159// All nodes are created within "allocator".
160static PredNode *
161buildPredicateTree(const Pred &root,
162 SpecificBumpPtrAllocator<PredNode> &allocator,
163 ArrayRef<Subst> substitutions) {
164 auto *rootNode = allocator.Allocate();
165 new (rootNode) PredNode;
166 rootNode->kind = getPredCombinerKind(pred: root);
167 rootNode->predicate = &root;
168 if (!root.isCombined()) {
169 rootNode->expr = root.getCondition();
170 performSubstitutions(str&: rootNode->expr, substitutions);
171 return rootNode;
172 }
173
174 // If the current combined predicate is a leaf substitution, append it to the
175 // list before continuing.
176 auto allSubstitutions = llvm::to_vector<4>(Range&: substitutions);
177 if (rootNode->kind == PredCombinerKind::SubstLeaves) {
178 const auto &substPred = static_cast<const SubstLeavesPred &>(root);
179 allSubstitutions.push_back(
180 Elt: {substPred.getPattern(), substPred.getReplacement()});
181
182 // If the current predicate is a ConcatPred, record the prefix and suffix.
183 } else if (rootNode->kind == PredCombinerKind::Concat) {
184 const auto &concatPred = static_cast<const ConcatPred &>(root);
185 rootNode->prefix = std::string(concatPred.getPrefix());
186 performSubstitutions(str&: rootNode->prefix, substitutions);
187 rootNode->suffix = std::string(concatPred.getSuffix());
188 performSubstitutions(str&: rootNode->suffix, substitutions);
189 }
190
191 // Build child subtrees.
192 auto combined = static_cast<const CombinedPred &>(root);
193 for (const auto *record : combined.getChildren()) {
194 auto *childTree =
195 buildPredicateTree(root: Pred(record), allocator, substitutions: allSubstitutions);
196 rootNode->children.push_back(Elt: childTree);
197 }
198 return rootNode;
199}
200
201// Simplify a predicate tree rooted at "node" using the predicates that are
202// known to be true(false). For AND(OR) combined predicates, if any of the
203// children is known to be false(true), the result is also false(true).
204// Furthermore, for AND(OR) combined predicates, children that are known to be
205// true(false) don't have to be checked dynamically.
206static PredNode *
207propagateGroundTruth(PredNode *node,
208 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
209 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
210 // If the current predicate is known to be true or false, change the kind of
211 // the node and return immediately.
212 if (knownTruePreds.count(Ptr: node->predicate) != 0) {
213 node->kind = PredCombinerKind::True;
214 node->children.clear();
215 return node;
216 }
217 if (knownFalsePreds.count(Ptr: node->predicate) != 0) {
218 node->kind = PredCombinerKind::False;
219 node->children.clear();
220 return node;
221 }
222
223 // If the current node is a substitution, stop recursion now.
224 // The expressions in the leaves below this node were rewritten, but the nodes
225 // still point to the original predicate records. While the original
226 // predicate may be known to be true or false, it is not necessarily the case
227 // after rewriting.
228 // TODO: we can support ground truth for rewritten
229 // predicates by either (a) having our own unique'ing of the predicates
230 // instead of relying on TableGen record pointers or (b) taking ground truth
231 // values optionally prefixed with a list of substitutions to apply, e.g.
232 // "predX is true by itself as well as predSubY leaf substitution had been
233 // applied to it".
234 if (node->kind == PredCombinerKind::SubstLeaves) {
235 return node;
236 }
237
238 if (node->kind == PredCombinerKind::And && node->children.empty()) {
239 node->kind = PredCombinerKind::True;
240 return node;
241 }
242
243 if (node->kind == PredCombinerKind::Or && node->children.empty()) {
244 node->kind = PredCombinerKind::False;
245 return node;
246 }
247
248 // Otherwise, look at child nodes.
249
250 // Move child nodes into some local variable so that they can be optimized
251 // separately and re-added if necessary.
252 llvm::SmallVector<PredNode *, 4> children;
253 std::swap(LHS&: node->children, RHS&: children);
254
255 for (auto &child : children) {
256 // First, simplify the child. This maintains the predicate as it was.
257 auto *simplifiedChild =
258 propagateGroundTruth(node: child, knownTruePreds, knownFalsePreds);
259
260 // Just add the child if we don't know how to simplify the current node.
261 if (node->kind != PredCombinerKind::And &&
262 node->kind != PredCombinerKind::Or) {
263 node->children.push_back(Elt: simplifiedChild);
264 continue;
265 }
266
267 // Second, based on the type define which known values of child predicates
268 // immediately collapse this predicate to a known value, and which others
269 // may be safely ignored.
270 // OR(..., True, ...) = True
271 // OR(..., False, ...) = OR(..., ...)
272 // AND(..., False, ...) = False
273 // AND(..., True, ...) = AND(..., ...)
274 auto collapseKind = node->kind == PredCombinerKind::And
275 ? PredCombinerKind::False
276 : PredCombinerKind::True;
277 auto eraseKind = node->kind == PredCombinerKind::And
278 ? PredCombinerKind::True
279 : PredCombinerKind::False;
280 const auto &collapseList =
281 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
282 const auto &eraseList =
283 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
284 if (simplifiedChild->kind == collapseKind ||
285 collapseList.count(Ptr: simplifiedChild->predicate) != 0) {
286 node->kind = collapseKind;
287 node->children.clear();
288 return node;
289 }
290 if (simplifiedChild->kind == eraseKind ||
291 eraseList.count(Ptr: simplifiedChild->predicate) != 0) {
292 continue;
293 }
294 node->children.push_back(Elt: simplifiedChild);
295 }
296 return node;
297}
298
299// Combine a list of predicate expressions using a binary combiner. If a list
300// is empty, return "init".
301static std::string combineBinary(ArrayRef<std::string> children,
302 const std::string &combiner,
303 std::string init) {
304 if (children.empty())
305 return init;
306
307 auto size = children.size();
308 if (size == 1)
309 return children.front();
310
311 std::string str;
312 llvm::raw_string_ostream os(str);
313 os << '(' << children.front() << ')';
314 for (unsigned i = 1; i < size; ++i) {
315 os << ' ' << combiner << " (" << children[i] << ')';
316 }
317 return str;
318}
319
320// Prepend negation to the only condition in the predicate expression list.
321static std::string combineNot(ArrayRef<std::string> children) {
322 assert(children.size() == 1 && "expected exactly one child predicate of Neg");
323 return (Twine("!(") + children.front() + Twine(')')).str();
324}
325
326// Recursively traverse the predicate tree in depth-first post-order and build
327// the final expression.
328static std::string getCombinedCondition(const PredNode &root) {
329 // Immediately return for non-combiner predicates that don't have children.
330 if (root.kind == PredCombinerKind::Leaf)
331 return root.expr;
332 if (root.kind == PredCombinerKind::True)
333 return "true";
334 if (root.kind == PredCombinerKind::False)
335 return "false";
336
337 // Recurse into children.
338 llvm::SmallVector<std::string, 4> childExpressions;
339 childExpressions.reserve(N: root.children.size());
340 for (const auto &child : root.children)
341 childExpressions.push_back(Elt: getCombinedCondition(root: *child));
342
343 // Combine the expressions based on the predicate node kind.
344 if (root.kind == PredCombinerKind::And)
345 return combineBinary(children: childExpressions, combiner: "&&", init: "true");
346 if (root.kind == PredCombinerKind::Or)
347 return combineBinary(children: childExpressions, combiner: "||", init: "false");
348 if (root.kind == PredCombinerKind::Not)
349 return combineNot(children: childExpressions);
350 if (root.kind == PredCombinerKind::Concat) {
351 assert(childExpressions.size() == 1 &&
352 "ConcatPred should only have one child");
353 return root.prefix + childExpressions.front() + root.suffix;
354 }
355
356 // Substitutions were applied before so just ignore them.
357 if (root.kind == PredCombinerKind::SubstLeaves) {
358 assert(childExpressions.size() == 1 &&
359 "substitution predicate must have one child");
360 return childExpressions[0];
361 }
362
363 llvm::PrintFatalError(ErrorLoc: root.predicate->getLoc(), Msg: "unsupported predicate kind");
364}
365
366std::string CombinedPred::getConditionImpl() const {
367 SpecificBumpPtrAllocator<PredNode> allocator;
368 auto *predicateTree = buildPredicateTree(root: *this, allocator, substitutions: {});
369 predicateTree =
370 propagateGroundTruth(node: predicateTree,
371 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
372 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
373
374 return getCombinedCondition(root: *predicateTree);
375}
376
377StringRef SubstLeavesPred::getPattern() const {
378 return def->getValueAsString(FieldName: "pattern");
379}
380
381StringRef SubstLeavesPred::getReplacement() const {
382 return def->getValueAsString(FieldName: "replacement");
383}
384
385StringRef ConcatPred::getPrefix() const {
386 return def->getValueAsString(FieldName: "prefix");
387}
388
389StringRef ConcatPred::getSuffix() const {
390 return def->getValueAsString(FieldName: "suffix");
391}
392
// End of mlir/lib/TableGen/Predicate.cpp. (Code-browser page footer removed:
// hosting credit, privacy-policy link and training advertisement.)