1//===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDL/IR/PDL.h"
10#include "mlir/Dialect/PDL/IR/PDLOps.h"
11#include "mlir/Dialect/PDL/IR/PDLTypes.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/Interfaces/InferTypeOpInterface.h"
14#include "llvm/ADT/DenseSet.h"
15#include "llvm/ADT/TypeSwitch.h"
16#include <optional>
17
18using namespace mlir;
19using namespace mlir::pdl;
20
21#include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// PDLDialect
25//===----------------------------------------------------------------------===//
26
/// Initializes the PDL dialect: registers every operation in the
/// TableGen-generated op list (`GET_OP_LIST` section of PDLOps.cpp.inc), then
/// calls `registerTypes()`.
void PDLDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
      >();
  // NOTE(review): `registerTypes` is defined outside this file; presumably it
  // registers the PDL types declared in PDLTypes.h -- confirm.
  registerTypes();
}
34
35//===----------------------------------------------------------------------===//
36// PDL Operations
37//===----------------------------------------------------------------------===//
38
39/// Returns true if the given operation is used by a "binding" pdl operation.
40static bool hasBindingUse(Operation *op) {
41 for (Operation *user : op->getUsers())
42 // A result by itself is not binding, it must also be bound.
43 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
44 return true;
45 return false;
46}
47
48/// Returns success if the given operation is not in the main matcher body or
49/// is used by a "binding" operation. On failure, emits an error.
50static LogicalResult verifyHasBindingUse(Operation *op) {
51 // If the parent is not a pattern, there is nothing to do.
52 if (!llvm::isa_and_nonnull<PatternOp>(op->getParentOp()))
53 return success();
54 if (hasBindingUse(op))
55 return success();
56 return op->emitOpError(
57 message: "expected a bindable user when defined in the matcher body of a "
58 "`pdl.pattern`");
59}
60
61/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
62/// connected to the given operation.
63static void visit(Operation *op, DenseSet<Operation *> &visited) {
64 // If the parent is not a pattern, there is nothing to do.
65 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
66 return;
67
68 // Ignore if already visited. Otherwise, mark as visited.
69 if (!visited.insert(V: op).second)
70 return;
71
72 // Traverse the operands / parent.
73 TypeSwitch<Operation *>(op)
74 .Case<OperationOp>([&visited](auto operation) {
75 for (Value operand : operation.getOperandValues())
76 visit(operand.getDefiningOp(), visited);
77 })
78 .Case<ResultOp, ResultsOp>([&visited](auto result) {
79 visit(result.getParent().getDefiningOp(), visited);
80 });
81
82 // Traverse the users.
83 for (Operation *user : op->getUsers())
84 visit(op: user, visited);
85}
86
87//===----------------------------------------------------------------------===//
88// pdl::ApplyNativeConstraintOp
89//===----------------------------------------------------------------------===//
90
91LogicalResult ApplyNativeConstraintOp::verify() {
92 if (getNumOperands() == 0)
93 return emitOpError("expected at least one argument");
94 if (llvm::any_of(getResults(), [](OpResult result) {
95 return isa<OperationType>(result.getType());
96 })) {
97 return emitOpError(
98 "returning an operation from a constraint is not supported");
99 }
100 return success();
101}
102
103//===----------------------------------------------------------------------===//
104// pdl::ApplyNativeRewriteOp
105//===----------------------------------------------------------------------===//
106
107LogicalResult ApplyNativeRewriteOp::verify() {
108 if (getNumOperands() == 0 && getNumResults() == 0)
109 return emitOpError("expected at least one argument or result");
110 return success();
111}
112
113//===----------------------------------------------------------------------===//
114// pdl::AttributeOp
115//===----------------------------------------------------------------------===//
116
117LogicalResult AttributeOp::verify() {
118 Value attrType = getValueType();
119 std::optional<Attribute> attrValue = getValue();
120
121 if (!attrValue) {
122 if (isa<RewriteOp>((*this)->getParentOp()))
123 return emitOpError(
124 "expected constant value when specified within a `pdl.rewrite`");
125 return verifyHasBindingUse(*this);
126 }
127 if (attrType)
128 return emitOpError("expected only one of [`type`, `value`] to be set");
129 return success();
130}
131
132//===----------------------------------------------------------------------===//
133// pdl::OperandOp
134//===----------------------------------------------------------------------===//
135
136LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); }
137
138//===----------------------------------------------------------------------===//
139// pdl::OperandsOp
140//===----------------------------------------------------------------------===//
141
142LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
143
144//===----------------------------------------------------------------------===//
145// pdl::OperationOp
146//===----------------------------------------------------------------------===//
147
148static ParseResult parseOperationOpAttributes(
149 OpAsmParser &p,
150 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
151 ArrayAttr &attrNamesAttr) {
152 Builder &builder = p.getBuilder();
153 SmallVector<Attribute, 4> attrNames;
154 if (succeeded(Result: p.parseOptionalLBrace())) {
155 auto parseOperands = [&]() {
156 StringAttr nameAttr;
157 OpAsmParser::UnresolvedOperand operand;
158 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
159 p.parseOperand(result&: operand))
160 return failure();
161 attrNames.push_back(Elt: nameAttr);
162 attrOperands.push_back(Elt: operand);
163 return success();
164 };
165 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
166 return failure();
167 }
168 attrNamesAttr = builder.getArrayAttr(attrNames);
169 return success();
170}
171
172static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
173 OperandRange attrArgs,
174 ArrayAttr attrNames) {
175 if (attrNames.empty())
176 return;
177 p << " {";
178 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
179 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
180 p << '}';
181}
182
/// Verifies that the result types of `op`, an operation defined within a
/// `pdl.rewrite`, can be inferred. Emits an error (with an explanatory note)
/// on `op` and returns failure otherwise.
///
/// The check passes when any of the following holds:
///  * The value produced by `op` is used by a `pdl.replace` as any operand
///    other than the first (the operation being replaced), and that replaced
///    operation is defined in a different block or earlier in the same block.
///  * `op` has no explicit result types, and one of these is true: it has no
///    operation name, the name is not a registered operation, or the
///    registered operation has the `ZeroResults` or `VariadicResults` trait.
///  * Every explicit result type is defined by a `pdl.apply_native_rewrite`,
///    or by a `pdl.type`/`pdl.types` that is either constant or also used by
///    a `pdl.operand`, `pdl.operands` or `pdl.operation` outside the block
///    containing `op`.
static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
                                                    OperandRange resultTypes) {
  // Functor that returns true if the given use can be used to infer a type.
  Block *rewriterBlock = op->getBlock();
  auto canInferTypeFromUse = [&](OpOperand &use) {
    // If the use is within a ReplaceOp and isn't the operation being replaced
    // (i.e. is not the first operand of the replacement), we can infer a type.
    ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
    if (!replOpUser || use.getOperandNumber() == 0)
      return false;
    // Make sure the replaced operation was defined before this one.
    // NOTE(review): `getDefiningOp()` is not null-checked here; this assumes
    // the replaced value always has a defining operation -- confirm.
    Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
    return replacedOp->getBlock() != rewriterBlock ||
           replacedOp->isBeforeInBlock(op);
  };

  // Check to see if the uses of the operation itself can be used to infer
  // types.
  if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
    return success();

  // Handle the case where the operation has no explicit result types.
  if (resultTypes.empty()) {
    // If we don't know the concrete operation, don't attempt any verification.
    // We can't make assumptions if we don't know the concrete operation.
    std::optional<StringRef> rawOpName = op.getOpName();
    if (!rawOpName)
      return success();
    std::optional<RegisteredOperationName> opName =
        RegisteredOperationName::lookup(*rawOpName, op.getContext());
    if (!opName)
      return success();

    // If no explicit result types were provided, check to see if the operation
    // expected at least one result. This doesn't cover all cases, but this
    // should cover many cases in which the user intended to infer the results
    // of an operation, but it isn't actually possible.
    bool expectedAtLeastOneResult =
        !opName->hasTrait<OpTrait::ZeroResults>() &&
        !opName->hasTrait<OpTrait::VariadicResults>();
    if (expectedAtLeastOneResult) {
      return op
          .emitOpError("must have inferable or constrained result types when "
                       "nested within `pdl.rewrite`")
          .attachNote()
          .append("operation is created in a non-inferrable context, but '",
                  *opName, "' does not implement InferTypeOpInterface");
    }
    return success();
  }

  // Otherwise, make sure each of the types can be inferred.
  for (const auto &it : llvm::enumerate(resultTypes)) {
    Operation *resultTypeOp = it.value().getDefiningOp();
    assert(resultTypeOp && "expected valid result type operation");

    // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
    // usable.
    if (isa<ApplyNativeRewriteOp>(resultTypeOp))
      continue;

    // If the type operation was defined in the matcher and constrains an
    // operand or the result of an input operation, it can be used.
    // "In the matcher" is approximated as "in a block other than the one
    // containing `op`".
    auto constrainsInput = [rewriterBlock](Operation *user) {
      return user->getBlock() != rewriterBlock &&
             isa<OperandOp, OperandsOp, OperationOp>(user);
    };
    if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
      if (typeOp.getConstantType() ||
          llvm::any_of(typeOp->getUsers(), constrainsInput))
        continue;
    } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
      if (typeOp.getConstantTypes() ||
          llvm::any_of(typeOp->getUsers(), constrainsInput))
        continue;
    }

    // The first result type that cannot be resolved is reported by index.
    return op
        .emitOpError("must have inferable or constrained result types when "
                     "nested within `pdl.rewrite`")
        .attachNote()
        .append("result type #", it.index(), " was not constrained");
  }
  return success();
}
270
271LogicalResult OperationOp::verify() {
272 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
273 if (isWithinRewrite && !getOpName())
274 return emitOpError("must have an operation name when nested within "
275 "a `pdl.rewrite`");
276 ArrayAttr attributeNames = getAttributeValueNamesAttr();
277 auto attributeValues = getAttributeValues();
278 if (attributeNames.size() != attributeValues.size()) {
279 return emitOpError()
280 << "expected the same number of attribute values and attribute "
281 "names, got "
282 << attributeNames.size() << " names and " << attributeValues.size()
283 << " values";
284 }
285
286 // If the operation is within a rewrite body and doesn't have type inference,
287 // ensure that the result types can be resolved.
288 if (isWithinRewrite && !mightHaveTypeInference()) {
289 if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
290 return failure();
291 }
292
293 return verifyHasBindingUse(*this);
294}
295
296bool OperationOp::hasTypeInference() {
297 if (std::optional<StringRef> rawOpName = getOpName()) {
298 OperationName opName(*rawOpName, getContext());
299 return opName.hasInterface<InferTypeOpInterface>();
300 }
301 return false;
302}
303
304bool OperationOp::mightHaveTypeInference() {
305 if (std::optional<StringRef> rawOpName = getOpName()) {
306 OperationName opName(*rawOpName, getContext());
307 return opName.mightHaveInterface<InferTypeOpInterface>();
308 }
309 return false;
310}
311
312//===----------------------------------------------------------------------===//
313// pdl::PatternOp
314//===----------------------------------------------------------------------===//
315
316LogicalResult PatternOp::verifyRegions() {
317 Region &body = getBodyRegion();
318 Operation *term = body.front().getTerminator();
319 auto rewriteOp = dyn_cast<RewriteOp>(term);
320 if (!rewriteOp) {
321 return emitOpError("expected body to terminate with `pdl.rewrite`")
322 .attachNote(term->getLoc())
323 .append("see terminator defined here");
324 }
325
326 // Check that all values defined in the top-level pattern belong to the PDL
327 // dialect.
328 WalkResult result = body.walk([&](Operation *op) -> WalkResult {
329 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
330 emitOpError("expected only `pdl` operations within the pattern body")
331 .attachNote(op->getLoc())
332 .append("see non-`pdl` operation defined here");
333 return WalkResult::interrupt();
334 }
335 return WalkResult::advance();
336 });
337 if (result.wasInterrupted())
338 return failure();
339
340 // Check that there is at least one operation.
341 if (body.front().getOps<OperationOp>().empty())
342 return emitOpError("the pattern must contain at least one `pdl.operation`");
343
344 // Determine if the operations within the pdl.pattern form a connected
345 // component. This is determined by starting the search from the first
346 // operand/result/operation and visiting their users / parents / operands.
347 // We limit our attention to operations that have a user in pdl.rewrite,
348 // those that do not will be detected via other means (expected bindable
349 // user).
350 bool first = true;
351 DenseSet<Operation *> visited;
352 for (Operation &op : body.front()) {
353 // The following are the operations forming the connected component.
354 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
355 continue;
356
357 // Determine if the operation has a user in `pdl.rewrite`.
358 bool hasUserInRewrite = false;
359 for (Operation *user : op.getUsers()) {
360 Region *region = user->getParentRegion();
361 if (isa<RewriteOp>(user) ||
362 (region && isa<RewriteOp>(region->getParentOp()))) {
363 hasUserInRewrite = true;
364 break;
365 }
366 }
367
368 // If the operation does not have a user in `pdl.rewrite`, ignore it.
369 if (!hasUserInRewrite)
370 continue;
371
372 if (first) {
373 // For the first operation, invoke visit.
374 visit(&op, visited);
375 first = false;
376 } else if (!visited.count(&op)) {
377 // For the subsequent operations, check if already visited.
378 return emitOpError("the operations must form a connected component")
379 .attachNote(op.getLoc())
380 .append("see a disconnected value / operation here");
381 }
382 }
383
384 return success();
385}
386
387void PatternOp::build(OpBuilder &builder, OperationState &state,
388 std::optional<uint16_t> benefit,
389 std::optional<StringRef> name) {
390 build(builder, state, builder.getI16IntegerAttr(benefit.value_or(0)),
391 name ? builder.getStringAttr(*name) : StringAttr());
392 state.regions[0]->emplaceBlock();
393}
394
395/// Returns the rewrite operation of this pattern.
396RewriteOp PatternOp::getRewriter() {
397 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
398}
399
400/// The default dialect is `pdl`.
401StringRef PatternOp::getDefaultDialect() {
402 return PDLDialect::getDialectNamespace();
403}
404
405//===----------------------------------------------------------------------===//
406// pdl::RangeOp
407//===----------------------------------------------------------------------===//
408
409static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
410 Type &resultType) {
411 // If arguments were provided, infer the result type from the argument list.
412 if (!argumentTypes.empty()) {
413 resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
414 return success();
415 }
416 // Otherwise, parse the type as a trailing type.
417 return p.parseColonType(result&: resultType);
418}
419
420static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
421 Type resultType) {
422 if (argumentTypes.empty())
423 p << ": " << resultType;
424}
425
426LogicalResult RangeOp::verify() {
427 Type elementType = getType().getElementType();
428 for (Type operandType : getOperandTypes()) {
429 Type operandElementType = getRangeElementTypeOrSelf(operandType);
430 if (operandElementType != elementType) {
431 return emitOpError("expected operand to have element type ")
432 << elementType << ", but got " << operandElementType;
433 }
434 }
435 return success();
436}
437
438//===----------------------------------------------------------------------===//
439// pdl::ReplaceOp
440//===----------------------------------------------------------------------===//
441
442LogicalResult ReplaceOp::verify() {
443 if (getReplOperation() && !getReplValues().empty())
444 return emitOpError() << "expected no replacement values to be provided"
445 " when the replacement operation is present";
446 return success();
447}
448
449//===----------------------------------------------------------------------===//
450// pdl::ResultsOp
451//===----------------------------------------------------------------------===//
452
453static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
454 Type &resultType) {
455 if (!index) {
456 resultType = RangeType::get(p.getBuilder().getType<ValueType>());
457 return success();
458 }
459 if (p.parseArrow() || p.parseType(result&: resultType))
460 return failure();
461 return success();
462}
463
464static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
465 IntegerAttr index, Type resultType) {
466 if (index)
467 p << " -> " << resultType;
468}
469
470LogicalResult ResultsOp::verify() {
471 if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
472 return emitOpError() << "expected `pdl.range<value>` result type when "
473 "no index is specified, but got: "
474 << getType();
475 }
476 return success();
477}
478
479//===----------------------------------------------------------------------===//
480// pdl::RewriteOp
481//===----------------------------------------------------------------------===//
482
483LogicalResult RewriteOp::verifyRegions() {
484 Region &rewriteRegion = getBodyRegion();
485
486 // Handle the case where the rewrite is external.
487 if (getName()) {
488 if (!rewriteRegion.empty()) {
489 return emitOpError()
490 << "expected rewrite region to be empty when rewrite is external";
491 }
492 return success();
493 }
494
495 // Otherwise, check that the rewrite region only contains a single block.
496 if (rewriteRegion.empty()) {
497 return emitOpError() << "expected rewrite region to be non-empty if "
498 "external name is not specified";
499 }
500
501 // Check that no additional arguments were provided.
502 if (!getExternalArgs().empty()) {
503 return emitOpError() << "expected no external arguments when the "
504 "rewrite is specified inline";
505 }
506
507 return success();
508}
509
510/// The default dialect is `pdl`.
511StringRef RewriteOp::getDefaultDialect() {
512 return PDLDialect::getDialectNamespace();
513}
514
515//===----------------------------------------------------------------------===//
516// pdl::TypeOp
517//===----------------------------------------------------------------------===//
518
519LogicalResult TypeOp::verify() {
520 if (!getConstantTypeAttr())
521 return verifyHasBindingUse(*this);
522 return success();
523}
524
525//===----------------------------------------------------------------------===//
526// pdl::TypesOp
527//===----------------------------------------------------------------------===//
528
529LogicalResult TypesOp::verify() {
530 if (!getConstantTypesAttr())
531 return verifyHasBindingUse(*this);
532 return success();
533}
534
535//===----------------------------------------------------------------------===//
536// TableGen'd op method definitions
537//===----------------------------------------------------------------------===//
538
539#define GET_OP_CLASSES
540#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
541

Provided by KDAB

Privacy Policy
Improve your Profiling and Debugging skills
Find out more

source code of mlir/lib/Dialect/PDL/IR/PDL.cpp