1//===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDL/IR/PDL.h"
10#include "mlir/Dialect/PDL/IR/PDLOps.h"
11#include "mlir/Dialect/PDL/IR/PDLTypes.h"
12#include "mlir/IR/BuiltinTypes.h"
13#include "mlir/Interfaces/InferTypeOpInterface.h"
14#include "llvm/ADT/DenseSet.h"
15#include "llvm/ADT/TypeSwitch.h"
16#include <optional>
17
18using namespace mlir;
19using namespace mlir::pdl;
20
21#include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
22
23//===----------------------------------------------------------------------===//
24// PDLDialect
25//===----------------------------------------------------------------------===//
26
/// Dialect initialization hook: registers the PDL operations and then the PDL
/// types.
void PDLDialect::initialize() {
 // The template argument list of `addOperations` is supplied by the included
 // `.cpp.inc` file; `GET_OP_LIST` is defined immediately before the include
 // to select the op-list portion of that file (the same file is included
 // again at the end of this translation unit under `GET_OP_CLASSES`).
 addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
 >();
 // NOTE(review): `registerTypes` is defined outside this file; presumably it
 // registers the types declared in PDLTypes.h -- confirm.
 registerTypes();
}
34
35//===----------------------------------------------------------------------===//
36// PDL Operations
37//===----------------------------------------------------------------------===//
38
39/// Returns true if the given operation is used by a "binding" pdl operation.
40static bool hasBindingUse(Operation *op) {
41 for (Operation *user : op->getUsers())
42 // A result by itself is not binding, it must also be bound.
43 if (!isa<ResultOp, ResultsOp>(user) || hasBindingUse(user))
44 return true;
45 return false;
46}
47
48/// Returns success if the given operation is not in the main matcher body or
49/// is used by a "binding" operation. On failure, emits an error.
50static LogicalResult verifyHasBindingUse(Operation *op) {
51 // If the parent is not a pattern, there is nothing to do.
52 if (!llvm::isa_and_nonnull<PatternOp>(Val: op->getParentOp()))
53 return success();
54 if (hasBindingUse(op))
55 return success();
56 return op->emitOpError(
57 message: "expected a bindable user when defined in the matcher body of a "
58 "`pdl.pattern`");
59}
60
61/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
62/// connected to the given operation.
63static void visit(Operation *op, DenseSet<Operation *> &visited) {
64 // If the parent is not a pattern, there is nothing to do.
65 if (!isa<PatternOp>(op->getParentOp()) || isa<RewriteOp>(op))
66 return;
67
68 // Ignore if already visited.
69 if (visited.contains(V: op))
70 return;
71
72 // Mark as visited.
73 visited.insert(V: op);
74
75 // Traverse the operands / parent.
76 TypeSwitch<Operation *>(op)
77 .Case<OperationOp>([&visited](auto operation) {
78 for (Value operand : operation.getOperandValues())
79 visit(operand.getDefiningOp(), visited);
80 })
81 .Case<ResultOp, ResultsOp>([&visited](auto result) {
82 visit(result.getParent().getDefiningOp(), visited);
83 });
84
85 // Traverse the users.
86 for (Operation *user : op->getUsers())
87 visit(op: user, visited);
88}
89
90//===----------------------------------------------------------------------===//
91// pdl::ApplyNativeConstraintOp
92//===----------------------------------------------------------------------===//
93
94LogicalResult ApplyNativeConstraintOp::verify() {
95 if (getNumOperands() == 0)
96 return emitOpError("expected at least one argument");
97 if (llvm::any_of(getResults(), [](OpResult result) {
98 return isa<OperationType>(result.getType());
99 })) {
100 return emitOpError(
101 "returning an operation from a constraint is not supported");
102 }
103 return success();
104}
105
106//===----------------------------------------------------------------------===//
107// pdl::ApplyNativeRewriteOp
108//===----------------------------------------------------------------------===//
109
110LogicalResult ApplyNativeRewriteOp::verify() {
111 if (getNumOperands() == 0 && getNumResults() == 0)
112 return emitOpError("expected at least one argument or result");
113 return success();
114}
115
116//===----------------------------------------------------------------------===//
117// pdl::AttributeOp
118//===----------------------------------------------------------------------===//
119
120LogicalResult AttributeOp::verify() {
121 Value attrType = getValueType();
122 std::optional<Attribute> attrValue = getValue();
123
124 if (!attrValue) {
125 if (isa<RewriteOp>((*this)->getParentOp()))
126 return emitOpError(
127 "expected constant value when specified within a `pdl.rewrite`");
128 return verifyHasBindingUse(*this);
129 }
130 if (attrType)
131 return emitOpError("expected only one of [`type`, `value`] to be set");
132 return success();
133}
134
135//===----------------------------------------------------------------------===//
136// pdl::OperandOp
137//===----------------------------------------------------------------------===//
138
139LogicalResult OperandOp::verify() { return verifyHasBindingUse(*this); }
140
141//===----------------------------------------------------------------------===//
142// pdl::OperandsOp
143//===----------------------------------------------------------------------===//
144
145LogicalResult OperandsOp::verify() { return verifyHasBindingUse(*this); }
146
147//===----------------------------------------------------------------------===//
148// pdl::OperationOp
149//===----------------------------------------------------------------------===//
150
151static ParseResult parseOperationOpAttributes(
152 OpAsmParser &p,
153 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
154 ArrayAttr &attrNamesAttr) {
155 Builder &builder = p.getBuilder();
156 SmallVector<Attribute, 4> attrNames;
157 if (succeeded(result: p.parseOptionalLBrace())) {
158 auto parseOperands = [&]() {
159 StringAttr nameAttr;
160 OpAsmParser::UnresolvedOperand operand;
161 if (p.parseAttribute(nameAttr) || p.parseEqual() ||
162 p.parseOperand(result&: operand))
163 return failure();
164 attrNames.push_back(Elt: nameAttr);
165 attrOperands.push_back(Elt: operand);
166 return success();
167 };
168 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
169 return failure();
170 }
171 attrNamesAttr = builder.getArrayAttr(attrNames);
172 return success();
173}
174
175static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
176 OperandRange attrArgs,
177 ArrayAttr attrNames) {
178 if (attrNames.empty())
179 return;
180 p << " {";
181 interleaveComma(llvm::seq<int>(0, attrNames.size()), p,
182 [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
183 p << '}';
184}
185
186/// Verifies that the result types of this operation, defined within a
187/// `pdl.rewrite`, can be inferred.
188static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
189 OperandRange resultTypes) {
190 // Functor that returns if the given use can be used to infer a type.
191 Block *rewriterBlock = op->getBlock();
192 auto canInferTypeFromUse = [&](OpOperand &use) {
193 // If the use is within a ReplaceOp and isn't the operation being replaced
194 // (i.e. is not the first operand of the replacement), we can infer a type.
195 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(use.getOwner());
196 if (!replOpUser || use.getOperandNumber() == 0)
197 return false;
198 // Make sure the replaced operation was defined before this one.
199 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
200 return replacedOp->getBlock() != rewriterBlock ||
201 replacedOp->isBeforeInBlock(op);
202 };
203
204 // Check to see if the uses of the operation itself can be used to infer
205 // types.
206 if (llvm::any_of(op.getOp().getUses(), canInferTypeFromUse))
207 return success();
208
209 // Handle the case where the operation has no explicit result types.
210 if (resultTypes.empty()) {
211 // If we don't know the concrete operation, don't attempt any verification.
212 // We can't make assumptions if we don't know the concrete operation.
213 std::optional<StringRef> rawOpName = op.getOpName();
214 if (!rawOpName)
215 return success();
216 std::optional<RegisteredOperationName> opName =
217 RegisteredOperationName::lookup(*rawOpName, op.getContext());
218 if (!opName)
219 return success();
220
221 // If no explicit result types were provided, check to see if the operation
222 // expected at least one result. This doesn't cover all cases, but this
223 // should cover many cases in which the user intended to infer the results
224 // of an operation, but it isn't actually possible.
225 bool expectedAtLeastOneResult =
226 !opName->hasTrait<OpTrait::ZeroResults>() &&
227 !opName->hasTrait<OpTrait::VariadicResults>();
228 if (expectedAtLeastOneResult) {
229 return op
230 .emitOpError("must have inferable or constrained result types when "
231 "nested within `pdl.rewrite`")
232 .attachNote()
233 .append("operation is created in a non-inferrable context, but '",
234 *opName, "' does not implement InferTypeOpInterface");
235 }
236 return success();
237 }
238
239 // Otherwise, make sure each of the types can be inferred.
240 for (const auto &it : llvm::enumerate(First&: resultTypes)) {
241 Operation *resultTypeOp = it.value().getDefiningOp();
242 assert(resultTypeOp && "expected valid result type operation");
243
244 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
245 // usable.
246 if (isa<ApplyNativeRewriteOp>(resultTypeOp))
247 continue;
248
249 // If the type operation was defined in the matcher and constrains an
250 // operand or the result of an input operation, it can be used.
251 auto constrainsInput = [rewriterBlock](Operation *user) {
252 return user->getBlock() != rewriterBlock &&
253 isa<OperandOp, OperandsOp, OperationOp>(user);
254 };
255 if (TypeOp typeOp = dyn_cast<TypeOp>(resultTypeOp)) {
256 if (typeOp.getConstantType() ||
257 llvm::any_of(typeOp->getUsers(), constrainsInput))
258 continue;
259 } else if (TypesOp typeOp = dyn_cast<TypesOp>(resultTypeOp)) {
260 if (typeOp.getConstantTypes() ||
261 llvm::any_of(typeOp->getUsers(), constrainsInput))
262 continue;
263 }
264
265 return op
266 .emitOpError("must have inferable or constrained result types when "
267 "nested within `pdl.rewrite`")
268 .attachNote()
269 .append("result type #", it.index(), " was not constrained");
270 }
271 return success();
272}
273
274LogicalResult OperationOp::verify() {
275 bool isWithinRewrite = isa_and_nonnull<RewriteOp>((*this)->getParentOp());
276 if (isWithinRewrite && !getOpName())
277 return emitOpError("must have an operation name when nested within "
278 "a `pdl.rewrite`");
279 ArrayAttr attributeNames = getAttributeValueNamesAttr();
280 auto attributeValues = getAttributeValues();
281 if (attributeNames.size() != attributeValues.size()) {
282 return emitOpError()
283 << "expected the same number of attribute values and attribute "
284 "names, got "
285 << attributeNames.size() << " names and " << attributeValues.size()
286 << " values";
287 }
288
289 // If the operation is within a rewrite body and doesn't have type inference,
290 // ensure that the result types can be resolved.
291 if (isWithinRewrite && !mightHaveTypeInference()) {
292 if (failed(verifyResultTypesAreInferrable(*this, getTypeValues())))
293 return failure();
294 }
295
296 return verifyHasBindingUse(*this);
297}
298
299bool OperationOp::hasTypeInference() {
300 if (std::optional<StringRef> rawOpName = getOpName()) {
301 OperationName opName(*rawOpName, getContext());
302 return opName.hasInterface<InferTypeOpInterface>();
303 }
304 return false;
305}
306
307bool OperationOp::mightHaveTypeInference() {
308 if (std::optional<StringRef> rawOpName = getOpName()) {
309 OperationName opName(*rawOpName, getContext());
310 return opName.mightHaveInterface<InferTypeOpInterface>();
311 }
312 return false;
313}
314
315//===----------------------------------------------------------------------===//
316// pdl::PatternOp
317//===----------------------------------------------------------------------===//
318
319LogicalResult PatternOp::verifyRegions() {
320 Region &body = getBodyRegion();
321 Operation *term = body.front().getTerminator();
322 auto rewriteOp = dyn_cast<RewriteOp>(term);
323 if (!rewriteOp) {
324 return emitOpError("expected body to terminate with `pdl.rewrite`")
325 .attachNote(term->getLoc())
326 .append("see terminator defined here");
327 }
328
329 // Check that all values defined in the top-level pattern belong to the PDL
330 // dialect.
331 WalkResult result = body.walk([&](Operation *op) -> WalkResult {
332 if (!isa_and_nonnull<PDLDialect>(op->getDialect())) {
333 emitOpError("expected only `pdl` operations within the pattern body")
334 .attachNote(op->getLoc())
335 .append("see non-`pdl` operation defined here");
336 return WalkResult::interrupt();
337 }
338 return WalkResult::advance();
339 });
340 if (result.wasInterrupted())
341 return failure();
342
343 // Check that there is at least one operation.
344 if (body.front().getOps<OperationOp>().empty())
345 return emitOpError("the pattern must contain at least one `pdl.operation`");
346
347 // Determine if the operations within the pdl.pattern form a connected
348 // component. This is determined by starting the search from the first
349 // operand/result/operation and visiting their users / parents / operands.
350 // We limit our attention to operations that have a user in pdl.rewrite,
351 // those that do not will be detected via other means (expected bindable
352 // user).
353 bool first = true;
354 DenseSet<Operation *> visited;
355 for (Operation &op : body.front()) {
356 // The following are the operations forming the connected component.
357 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(op))
358 continue;
359
360 // Determine if the operation has a user in `pdl.rewrite`.
361 bool hasUserInRewrite = false;
362 for (Operation *user : op.getUsers()) {
363 Region *region = user->getParentRegion();
364 if (isa<RewriteOp>(user) ||
365 (region && isa<RewriteOp>(region->getParentOp()))) {
366 hasUserInRewrite = true;
367 break;
368 }
369 }
370
371 // If the operation does not have a user in `pdl.rewrite`, ignore it.
372 if (!hasUserInRewrite)
373 continue;
374
375 if (first) {
376 // For the first operation, invoke visit.
377 visit(&op, visited);
378 first = false;
379 } else if (!visited.count(&op)) {
380 // For the subsequent operations, check if already visited.
381 return emitOpError("the operations must form a connected component")
382 .attachNote(op.getLoc())
383 .append("see a disconnected value / operation here");
384 }
385 }
386
387 return success();
388}
389
390void PatternOp::build(OpBuilder &builder, OperationState &state,
391 std::optional<uint16_t> benefit,
392 std::optional<StringRef> name) {
393 build(builder, state, builder.getI16IntegerAttr(benefit ? *benefit : 0),
394 name ? builder.getStringAttr(*name) : StringAttr());
395 state.regions[0]->emplaceBlock();
396}
397
398/// Returns the rewrite operation of this pattern.
399RewriteOp PatternOp::getRewriter() {
400 return cast<RewriteOp>(getBodyRegion().front().getTerminator());
401}
402
403/// The default dialect is `pdl`.
404StringRef PatternOp::getDefaultDialect() {
405 return PDLDialect::getDialectNamespace();
406}
407
408//===----------------------------------------------------------------------===//
409// pdl::RangeOp
410//===----------------------------------------------------------------------===//
411
412static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
413 Type &resultType) {
414 // If arguments were provided, infer the result type from the argument list.
415 if (!argumentTypes.empty()) {
416 resultType = RangeType::get(getRangeElementTypeOrSelf(argumentTypes[0]));
417 return success();
418 }
419 // Otherwise, parse the type as a trailing type.
420 return p.parseColonType(result&: resultType);
421}
422
423static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
424 Type resultType) {
425 if (argumentTypes.empty())
426 p << ": " << resultType;
427}
428
429LogicalResult RangeOp::verify() {
430 Type elementType = getType().getElementType();
431 for (Type operandType : getOperandTypes()) {
432 Type operandElementType = getRangeElementTypeOrSelf(operandType);
433 if (operandElementType != elementType) {
434 return emitOpError("expected operand to have element type ")
435 << elementType << ", but got " << operandElementType;
436 }
437 }
438 return success();
439}
440
441//===----------------------------------------------------------------------===//
442// pdl::ReplaceOp
443//===----------------------------------------------------------------------===//
444
445LogicalResult ReplaceOp::verify() {
446 if (getReplOperation() && !getReplValues().empty())
447 return emitOpError() << "expected no replacement values to be provided"
448 " when the replacement operation is present";
449 return success();
450}
451
452//===----------------------------------------------------------------------===//
453// pdl::ResultsOp
454//===----------------------------------------------------------------------===//
455
456static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
457 Type &resultType) {
458 if (!index) {
459 resultType = RangeType::get(p.getBuilder().getType<ValueType>());
460 return success();
461 }
462 if (p.parseArrow() || p.parseType(result&: resultType))
463 return failure();
464 return success();
465}
466
467static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
468 IntegerAttr index, Type resultType) {
469 if (index)
470 p << " -> " << resultType;
471}
472
473LogicalResult ResultsOp::verify() {
474 if (!getIndex() && llvm::isa<pdl::ValueType>(getType())) {
475 return emitOpError() << "expected `pdl.range<value>` result type when "
476 "no index is specified, but got: "
477 << getType();
478 }
479 return success();
480}
481
482//===----------------------------------------------------------------------===//
483// pdl::RewriteOp
484//===----------------------------------------------------------------------===//
485
486LogicalResult RewriteOp::verifyRegions() {
487 Region &rewriteRegion = getBodyRegion();
488
489 // Handle the case where the rewrite is external.
490 if (getName()) {
491 if (!rewriteRegion.empty()) {
492 return emitOpError()
493 << "expected rewrite region to be empty when rewrite is external";
494 }
495 return success();
496 }
497
498 // Otherwise, check that the rewrite region only contains a single block.
499 if (rewriteRegion.empty()) {
500 return emitOpError() << "expected rewrite region to be non-empty if "
501 "external name is not specified";
502 }
503
504 // Check that no additional arguments were provided.
505 if (!getExternalArgs().empty()) {
506 return emitOpError() << "expected no external arguments when the "
507 "rewrite is specified inline";
508 }
509
510 return success();
511}
512
513/// The default dialect is `pdl`.
514StringRef RewriteOp::getDefaultDialect() {
515 return PDLDialect::getDialectNamespace();
516}
517
518//===----------------------------------------------------------------------===//
519// pdl::TypeOp
520//===----------------------------------------------------------------------===//
521
522LogicalResult TypeOp::verify() {
523 if (!getConstantTypeAttr())
524 return verifyHasBindingUse(*this);
525 return success();
526}
527
528//===----------------------------------------------------------------------===//
529// pdl::TypesOp
530//===----------------------------------------------------------------------===//
531
532LogicalResult TypesOp::verify() {
533 if (!getConstantTypesAttr())
534 return verifyHasBindingUse(*this);
535 return success();
536}
537
538//===----------------------------------------------------------------------===//
539// TableGen'd op method definitions
540//===----------------------------------------------------------------------===//
541
542#define GET_OP_CLASSES
543#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
544

// Source: mlir/lib/Dialect/PDL/IR/PDL.cpp