1//===- Predicate.cpp - Predicate class ------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Wrapper around predicates defined in TableGen.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/TableGen/Predicate.h"
14#include "llvm/ADT/SmallPtrSet.h"
15#include "llvm/ADT/StringExtras.h"
16#include "llvm/ADT/StringSwitch.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "llvm/TableGen/Error.h"
19#include "llvm/TableGen/Record.h"
20
21using namespace mlir;
22using namespace tblgen;
23
24// Construct a Predicate from a record.
25Pred::Pred(const llvm::Record *record) : def(record) {
26 assert(def->isSubClassOf("Pred") &&
27 "must be a subclass of TableGen 'Pred' class");
28}
29
30// Construct a Predicate from an initializer.
31Pred::Pred(const llvm::Init *init) {
32 if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(Val: init))
33 def = defInit->getDef();
34}
35
36std::string Pred::getCondition() const {
37 // Static dispatch to subclasses.
38 if (def->isSubClassOf(Name: "CombinedPred"))
39 return static_cast<const CombinedPred *>(this)->getConditionImpl();
40 if (def->isSubClassOf(Name: "CPred"))
41 return static_cast<const CPred *>(this)->getConditionImpl();
42 llvm_unreachable("Pred::getCondition must be overridden in subclasses");
43}
44
45bool Pred::isCombined() const {
46 return def && def->isSubClassOf(Name: "CombinedPred");
47}
48
49ArrayRef<SMLoc> Pred::getLoc() const { return def->getLoc(); }
50
51CPred::CPred(const llvm::Record *record) : Pred(record) {
52 assert(def->isSubClassOf("CPred") &&
53 "must be a subclass of Tablegen 'CPred' class");
54}
55
56CPred::CPred(const llvm::Init *init) : Pred(init) {
57 assert((!def || def->isSubClassOf("CPred")) &&
58 "must be a subclass of Tablegen 'CPred' class");
59}
60
61// Get condition of the C Predicate.
62std::string CPred::getConditionImpl() const {
63 assert(!isNull() && "null predicate does not have a condition");
64 return std::string(def->getValueAsString(FieldName: "predExpr"));
65}
66
67CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
68 assert(def->isSubClassOf("CombinedPred") &&
69 "must be a subclass of Tablegen 'CombinedPred' class");
70}
71
72CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
73 assert((!def || def->isSubClassOf("CombinedPred")) &&
74 "must be a subclass of Tablegen 'CombinedPred' class");
75}
76
77const llvm::Record *CombinedPred::getCombinerDef() const {
78 assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
79 return def->getValueAsDef(FieldName: "kind");
80}
81
82std::vector<llvm::Record *> CombinedPred::getChildren() const {
83 assert(def->getValue("children") &&
84 "CombinedPred must have a value 'children'");
85 return def->getValueAsListOfDefs(FieldName: "children");
86}
87
88namespace {
89// Kinds of nodes in a logical predicate tree.
90enum class PredCombinerKind {
91 Leaf,
92 And,
93 Or,
94 Not,
95 SubstLeaves,
96 Concat,
97 // Special kinds that are used in simplification.
98 False,
99 True
100};
101
102// A node in a logical predicate tree.
103struct PredNode {
104 PredCombinerKind kind;
105 const Pred *predicate;
106 SmallVector<PredNode *, 4> children;
107 std::string expr;
108
109 // Prefix and suffix are used by ConcatPred.
110 std::string prefix;
111 std::string suffix;
112};
113} // namespace
114
115// Get a predicate tree node kind based on the kind used in the predicate
116// TableGen record.
117static PredCombinerKind getPredCombinerKind(const Pred &pred) {
118 if (!pred.isCombined())
119 return PredCombinerKind::Leaf;
120
121 const auto &combinedPred = static_cast<const CombinedPred &>(pred);
122 return StringSwitch<PredCombinerKind>(
123 combinedPred.getCombinerDef()->getName())
124 .Case(S: "PredCombinerAnd", Value: PredCombinerKind::And)
125 .Case(S: "PredCombinerOr", Value: PredCombinerKind::Or)
126 .Case(S: "PredCombinerNot", Value: PredCombinerKind::Not)
127 .Case(S: "PredCombinerSubstLeaves", Value: PredCombinerKind::SubstLeaves)
128 .Case(S: "PredCombinerConcat", Value: PredCombinerKind::Concat);
129}
130
131namespace {
132// Substitution<pattern, replacement>.
133using Subst = std::pair<StringRef, StringRef>;
134} // namespace
135
136/// Perform the given substitutions on 'str' in-place.
137static void performSubstitutions(std::string &str,
138 ArrayRef<Subst> substitutions) {
139 // Apply all parent substitutions from innermost to outermost.
140 for (const auto &subst : llvm::reverse(C&: substitutions)) {
141 auto pos = str.find(str: std::string(subst.first));
142 while (pos != std::string::npos) {
143 str.replace(pos: pos, n: subst.first.size(), str: std::string(subst.second));
144 // Skip the newly inserted substring, which itself may consider the
145 // pattern to match.
146 pos += subst.second.size();
147 // Find the next possible match position.
148 pos = str.find(str: std::string(subst.first), pos: pos);
149 }
150 }
151}
152
153// Build the predicate tree starting from the top-level predicate, which may
154// have children, and perform leaf substitutions inplace. Note that after
155// substitution, nodes are still pointing to the original TableGen record.
156// All nodes are created within "allocator".
157static PredNode *
158buildPredicateTree(const Pred &root,
159 llvm::SpecificBumpPtrAllocator<PredNode> &allocator,
160 ArrayRef<Subst> substitutions) {
161 auto *rootNode = allocator.Allocate();
162 new (rootNode) PredNode;
163 rootNode->kind = getPredCombinerKind(pred: root);
164 rootNode->predicate = &root;
165 if (!root.isCombined()) {
166 rootNode->expr = root.getCondition();
167 performSubstitutions(str&: rootNode->expr, substitutions);
168 return rootNode;
169 }
170
171 // If the current combined predicate is a leaf substitution, append it to the
172 // list before continuing.
173 auto allSubstitutions = llvm::to_vector<4>(Range&: substitutions);
174 if (rootNode->kind == PredCombinerKind::SubstLeaves) {
175 const auto &substPred = static_cast<const SubstLeavesPred &>(root);
176 allSubstitutions.push_back(
177 Elt: {substPred.getPattern(), substPred.getReplacement()});
178
179 // If the current predicate is a ConcatPred, record the prefix and suffix.
180 } else if (rootNode->kind == PredCombinerKind::Concat) {
181 const auto &concatPred = static_cast<const ConcatPred &>(root);
182 rootNode->prefix = std::string(concatPred.getPrefix());
183 performSubstitutions(str&: rootNode->prefix, substitutions);
184 rootNode->suffix = std::string(concatPred.getSuffix());
185 performSubstitutions(str&: rootNode->suffix, substitutions);
186 }
187
188 // Build child subtrees.
189 auto combined = static_cast<const CombinedPred &>(root);
190 for (const auto *record : combined.getChildren()) {
191 auto *childTree =
192 buildPredicateTree(root: Pred(record), allocator, substitutions: allSubstitutions);
193 rootNode->children.push_back(Elt: childTree);
194 }
195 return rootNode;
196}
197
198// Simplify a predicate tree rooted at "node" using the predicates that are
199// known to be true(false). For AND(OR) combined predicates, if any of the
200// children is known to be false(true), the result is also false(true).
201// Furthermore, for AND(OR) combined predicates, children that are known to be
202// true(false) don't have to be checked dynamically.
203static PredNode *
204propagateGroundTruth(PredNode *node,
205 const llvm::SmallPtrSetImpl<Pred *> &knownTruePreds,
206 const llvm::SmallPtrSetImpl<Pred *> &knownFalsePreds) {
207 // If the current predicate is known to be true or false, change the kind of
208 // the node and return immediately.
209 if (knownTruePreds.count(Ptr: node->predicate) != 0) {
210 node->kind = PredCombinerKind::True;
211 node->children.clear();
212 return node;
213 }
214 if (knownFalsePreds.count(Ptr: node->predicate) != 0) {
215 node->kind = PredCombinerKind::False;
216 node->children.clear();
217 return node;
218 }
219
220 // If the current node is a substitution, stop recursion now.
221 // The expressions in the leaves below this node were rewritten, but the nodes
222 // still point to the original predicate records. While the original
223 // predicate may be known to be true or false, it is not necessarily the case
224 // after rewriting.
225 // TODO: we can support ground truth for rewritten
226 // predicates by either (a) having our own unique'ing of the predicates
227 // instead of relying on TableGen record pointers or (b) taking ground truth
228 // values optionally prefixed with a list of substitutions to apply, e.g.
229 // "predX is true by itself as well as predSubY leaf substitution had been
230 // applied to it".
231 if (node->kind == PredCombinerKind::SubstLeaves) {
232 return node;
233 }
234
235 // Otherwise, look at child nodes.
236
237 // Move child nodes into some local variable so that they can be optimized
238 // separately and re-added if necessary.
239 llvm::SmallVector<PredNode *, 4> children;
240 std::swap(LHS&: node->children, RHS&: children);
241
242 for (auto &child : children) {
243 // First, simplify the child. This maintains the predicate as it was.
244 auto *simplifiedChild =
245 propagateGroundTruth(node: child, knownTruePreds, knownFalsePreds);
246
247 // Just add the child if we don't know how to simplify the current node.
248 if (node->kind != PredCombinerKind::And &&
249 node->kind != PredCombinerKind::Or) {
250 node->children.push_back(Elt: simplifiedChild);
251 continue;
252 }
253
254 // Second, based on the type define which known values of child predicates
255 // immediately collapse this predicate to a known value, and which others
256 // may be safely ignored.
257 // OR(..., True, ...) = True
258 // OR(..., False, ...) = OR(..., ...)
259 // AND(..., False, ...) = False
260 // AND(..., True, ...) = AND(..., ...)
261 auto collapseKind = node->kind == PredCombinerKind::And
262 ? PredCombinerKind::False
263 : PredCombinerKind::True;
264 auto eraseKind = node->kind == PredCombinerKind::And
265 ? PredCombinerKind::True
266 : PredCombinerKind::False;
267 const auto &collapseList =
268 node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
269 const auto &eraseList =
270 node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
271 if (simplifiedChild->kind == collapseKind ||
272 collapseList.count(Ptr: simplifiedChild->predicate) != 0) {
273 node->kind = collapseKind;
274 node->children.clear();
275 return node;
276 }
277 if (simplifiedChild->kind == eraseKind ||
278 eraseList.count(Ptr: simplifiedChild->predicate) != 0) {
279 continue;
280 }
281 node->children.push_back(Elt: simplifiedChild);
282 }
283 return node;
284}
285
286// Combine a list of predicate expressions using a binary combiner. If a list
287// is empty, return "init".
288static std::string combineBinary(ArrayRef<std::string> children,
289 const std::string &combiner,
290 std::string init) {
291 if (children.empty())
292 return init;
293
294 auto size = children.size();
295 if (size == 1)
296 return children.front();
297
298 std::string str;
299 llvm::raw_string_ostream os(str);
300 os << '(' << children.front() << ')';
301 for (unsigned i = 1; i < size; ++i) {
302 os << ' ' << combiner << " (" << children[i] << ')';
303 }
304 return os.str();
305}
306
307// Prepend negation to the only condition in the predicate expression list.
308static std::string combineNot(ArrayRef<std::string> children) {
309 assert(children.size() == 1 && "expected exactly one child predicate of Neg");
310 return (Twine("!(") + children.front() + Twine(')')).str();
311}
312
313// Recursively traverse the predicate tree in depth-first post-order and build
314// the final expression.
315static std::string getCombinedCondition(const PredNode &root) {
316 // Immediately return for non-combiner predicates that don't have children.
317 if (root.kind == PredCombinerKind::Leaf)
318 return root.expr;
319 if (root.kind == PredCombinerKind::True)
320 return "true";
321 if (root.kind == PredCombinerKind::False)
322 return "false";
323
324 // Recurse into children.
325 llvm::SmallVector<std::string, 4> childExpressions;
326 childExpressions.reserve(N: root.children.size());
327 for (const auto &child : root.children)
328 childExpressions.push_back(Elt: getCombinedCondition(root: *child));
329
330 // Combine the expressions based on the predicate node kind.
331 if (root.kind == PredCombinerKind::And)
332 return combineBinary(children: childExpressions, combiner: "&&", init: "true");
333 if (root.kind == PredCombinerKind::Or)
334 return combineBinary(children: childExpressions, combiner: "||", init: "false");
335 if (root.kind == PredCombinerKind::Not)
336 return combineNot(children: childExpressions);
337 if (root.kind == PredCombinerKind::Concat) {
338 assert(childExpressions.size() == 1 &&
339 "ConcatPred should only have one child");
340 return root.prefix + childExpressions.front() + root.suffix;
341 }
342
343 // Substitutions were applied before so just ignore them.
344 if (root.kind == PredCombinerKind::SubstLeaves) {
345 assert(childExpressions.size() == 1 &&
346 "substitution predicate must have one child");
347 return childExpressions[0];
348 }
349
350 llvm::PrintFatalError(ErrorLoc: root.predicate->getLoc(), Msg: "unsupported predicate kind");
351}
352
353std::string CombinedPred::getConditionImpl() const {
354 llvm::SpecificBumpPtrAllocator<PredNode> allocator;
355 auto *predicateTree = buildPredicateTree(root: *this, allocator, substitutions: {});
356 predicateTree =
357 propagateGroundTruth(node: predicateTree,
358 /*knownTruePreds=*/llvm::SmallPtrSet<Pred *, 2>(),
359 /*knownFalsePreds=*/llvm::SmallPtrSet<Pred *, 2>());
360
361 return getCombinedCondition(root: *predicateTree);
362}
363
364StringRef SubstLeavesPred::getPattern() const {
365 return def->getValueAsString(FieldName: "pattern");
366}
367
368StringRef SubstLeavesPred::getReplacement() const {
369 return def->getValueAsString(FieldName: "replacement");
370}
371
372StringRef ConcatPred::getPrefix() const {
373 return def->getValueAsString(FieldName: "prefix");
374}
375
376StringRef ConcatPred::getSuffix() const {
377 return def->getValueAsString(FieldName: "suffix");
378}
379

source code of mlir/lib/TableGen/Predicate.cpp