//===- PDL.cpp - Pattern Descriptor Language Dialect ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/PDL/IR/PDLOps.h"
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/ADT/TypeSwitch.h"
#include <optional>

using namespace mlir;
using namespace mlir::pdl;

#include "mlir/Dialect/PDL/IR/PDLOpsDialect.cpp.inc"
21
//===----------------------------------------------------------------------===//
// PDLDialect
//===----------------------------------------------------------------------===//
25
26void PDLDialect::initialize() {
27 addOperations<
28#define GET_OP_LIST
29#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
30 >();
31 registerTypes();
32}
33
//===----------------------------------------------------------------------===//
// PDL Operations
//===----------------------------------------------------------------------===//
37
38/// Returns true if the given operation is used by a "binding" pdl operation.
39static bool hasBindingUse(Operation *op) {
40 for (Operation *user : op->getUsers())
41 // A result by itself is not binding, it must also be bound.
42 if (!isa<ResultOp, ResultsOp>(Val: user) || hasBindingUse(op: user))
43 return true;
44 return false;
45}
46
47/// Returns success if the given operation is not in the main matcher body or
48/// is used by a "binding" operation. On failure, emits an error.
49static LogicalResult verifyHasBindingUse(Operation *op) {
50 // If the parent is not a pattern, there is nothing to do.
51 if (!llvm::isa_and_nonnull<PatternOp>(Val: op->getParentOp()))
52 return success();
53 if (hasBindingUse(op))
54 return success();
55 return op->emitOpError(
56 message: "expected a bindable user when defined in the matcher body of a "
57 "`pdl.pattern`");
58}
59
60/// Visits all the pdl.operand(s), pdl.result(s), and pdl.operation(s)
61/// connected to the given operation.
62static void visit(Operation *op, DenseSet<Operation *> &visited) {
63 // If the parent is not a pattern, there is nothing to do.
64 if (!isa<PatternOp>(Val: op->getParentOp()) || isa<RewriteOp>(Val: op))
65 return;
66
67 // Ignore if already visited. Otherwise, mark as visited.
68 if (!visited.insert(V: op).second)
69 return;
70
71 // Traverse the operands / parent.
72 TypeSwitch<Operation *>(op)
73 .Case<OperationOp>(caseFn: [&visited](auto operation) {
74 for (Value operand : operation.getOperandValues())
75 visit(op: operand.getDefiningOp(), visited);
76 })
77 .Case<ResultOp, ResultsOp>(caseFn: [&visited](auto result) {
78 visit(result.getParent().getDefiningOp(), visited);
79 });
80
81 // Traverse the users.
82 for (Operation *user : op->getUsers())
83 visit(op: user, visited);
84}
85
//===----------------------------------------------------------------------===//
// pdl::ApplyNativeConstraintOp
//===----------------------------------------------------------------------===//
89
90LogicalResult ApplyNativeConstraintOp::verify() {
91 if (getNumOperands() == 0)
92 return emitOpError(message: "expected at least one argument");
93 if (llvm::any_of(Range: getResults(), P: [](OpResult result) {
94 return isa<OperationType>(Val: result.getType());
95 })) {
96 return emitOpError(
97 message: "returning an operation from a constraint is not supported");
98 }
99 return success();
100}
101
//===----------------------------------------------------------------------===//
// pdl::ApplyNativeRewriteOp
//===----------------------------------------------------------------------===//
105
106LogicalResult ApplyNativeRewriteOp::verify() {
107 if (getNumOperands() == 0 && getNumResults() == 0)
108 return emitOpError(message: "expected at least one argument or result");
109 return success();
110}
111
//===----------------------------------------------------------------------===//
// pdl::AttributeOp
//===----------------------------------------------------------------------===//
115
116LogicalResult AttributeOp::verify() {
117 Value attrType = getValueType();
118 std::optional<Attribute> attrValue = getValue();
119
120 if (!attrValue) {
121 if (isa<RewriteOp>(Val: (*this)->getParentOp()))
122 return emitOpError(
123 message: "expected constant value when specified within a `pdl.rewrite`");
124 return verifyHasBindingUse(op: *this);
125 }
126 if (attrType)
127 return emitOpError(message: "expected only one of [`type`, `value`] to be set");
128 return success();
129}
130
//===----------------------------------------------------------------------===//
// pdl::OperandOp
//===----------------------------------------------------------------------===//
134
135LogicalResult OperandOp::verify() { return verifyHasBindingUse(op: *this); }
136
//===----------------------------------------------------------------------===//
// pdl::OperandsOp
//===----------------------------------------------------------------------===//
140
141LogicalResult OperandsOp::verify() { return verifyHasBindingUse(op: *this); }
142
//===----------------------------------------------------------------------===//
// pdl::OperationOp
//===----------------------------------------------------------------------===//
146
147static ParseResult parseOperationOpAttributes(
148 OpAsmParser &p,
149 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &attrOperands,
150 ArrayAttr &attrNamesAttr) {
151 Builder &builder = p.getBuilder();
152 SmallVector<Attribute, 4> attrNames;
153 if (succeeded(Result: p.parseOptionalLBrace())) {
154 auto parseOperands = [&]() {
155 StringAttr nameAttr;
156 OpAsmParser::UnresolvedOperand operand;
157 if (p.parseAttribute(result&: nameAttr) || p.parseEqual() ||
158 p.parseOperand(result&: operand))
159 return failure();
160 attrNames.push_back(Elt: nameAttr);
161 attrOperands.push_back(Elt: operand);
162 return success();
163 };
164 if (p.parseCommaSeparatedList(parseElementFn: parseOperands) || p.parseRBrace())
165 return failure();
166 }
167 attrNamesAttr = builder.getArrayAttr(value: attrNames);
168 return success();
169}
170
171static void printOperationOpAttributes(OpAsmPrinter &p, OperationOp op,
172 OperandRange attrArgs,
173 ArrayAttr attrNames) {
174 if (attrNames.empty())
175 return;
176 p << " {";
177 interleaveComma(c: llvm::seq<int>(Begin: 0, End: attrNames.size()), os&: p,
178 each_fn: [&](int i) { p << attrNames[i] << " = " << attrArgs[i]; });
179 p << '}';
180}
181
182/// Verifies that the result types of this operation, defined within a
183/// `pdl.rewrite`, can be inferred.
184static LogicalResult verifyResultTypesAreInferrable(OperationOp op,
185 OperandRange resultTypes) {
186 // Functor that returns if the given use can be used to infer a type.
187 Block *rewriterBlock = op->getBlock();
188 auto canInferTypeFromUse = [&](OpOperand &use) {
189 // If the use is within a ReplaceOp and isn't the operation being replaced
190 // (i.e. is not the first operand of the replacement), we can infer a type.
191 ReplaceOp replOpUser = dyn_cast<ReplaceOp>(Val: use.getOwner());
192 if (!replOpUser || use.getOperandNumber() == 0)
193 return false;
194 // Make sure the replaced operation was defined before this one.
195 Operation *replacedOp = replOpUser.getOpValue().getDefiningOp();
196 return replacedOp->getBlock() != rewriterBlock ||
197 replacedOp->isBeforeInBlock(other: op);
198 };
199
200 // Check to see if the uses of the operation itself can be used to infer
201 // types.
202 if (llvm::any_of(Range: op.getOp().getUses(), P: canInferTypeFromUse))
203 return success();
204
205 // Handle the case where the operation has no explicit result types.
206 if (resultTypes.empty()) {
207 // If we don't know the concrete operation, don't attempt any verification.
208 // We can't make assumptions if we don't know the concrete operation.
209 std::optional<StringRef> rawOpName = op.getOpName();
210 if (!rawOpName)
211 return success();
212 std::optional<RegisteredOperationName> opName =
213 RegisteredOperationName::lookup(name: *rawOpName, ctx: op.getContext());
214 if (!opName)
215 return success();
216
217 // If no explicit result types were provided, check to see if the operation
218 // expected at least one result. This doesn't cover all cases, but this
219 // should cover many cases in which the user intended to infer the results
220 // of an operation, but it isn't actually possible.
221 bool expectedAtLeastOneResult =
222 !opName->hasTrait<OpTrait::ZeroResults>() &&
223 !opName->hasTrait<OpTrait::VariadicResults>();
224 if (expectedAtLeastOneResult) {
225 return op
226 .emitOpError(message: "must have inferable or constrained result types when "
227 "nested within `pdl.rewrite`")
228 .attachNote()
229 .append(arg1: "operation is created in a non-inferrable context, but '",
230 arg2&: *opName, args: "' does not implement InferTypeOpInterface");
231 }
232 return success();
233 }
234
235 // Otherwise, make sure each of the types can be inferred.
236 for (const auto &it : llvm::enumerate(First&: resultTypes)) {
237 Operation *resultTypeOp = it.value().getDefiningOp();
238 assert(resultTypeOp && "expected valid result type operation");
239
240 // If the op was defined by a `apply_native_rewrite`, it is guaranteed to be
241 // usable.
242 if (isa<ApplyNativeRewriteOp>(Val: resultTypeOp))
243 continue;
244
245 // If the type operation was defined in the matcher and constrains an
246 // operand or the result of an input operation, it can be used.
247 auto constrainsInput = [rewriterBlock](Operation *user) {
248 return user->getBlock() != rewriterBlock &&
249 isa<OperandOp, OperandsOp, OperationOp>(Val: user);
250 };
251 if (TypeOp typeOp = dyn_cast<TypeOp>(Val: resultTypeOp)) {
252 if (typeOp.getConstantType() ||
253 llvm::any_of(Range: typeOp->getUsers(), P: constrainsInput))
254 continue;
255 } else if (TypesOp typeOp = dyn_cast<TypesOp>(Val: resultTypeOp)) {
256 if (typeOp.getConstantTypes() ||
257 llvm::any_of(Range: typeOp->getUsers(), P: constrainsInput))
258 continue;
259 }
260
261 return op
262 .emitOpError(message: "must have inferable or constrained result types when "
263 "nested within `pdl.rewrite`")
264 .attachNote()
265 .append(arg1: "result type #", arg2: it.index(), args: " was not constrained");
266 }
267 return success();
268}
269
270LogicalResult OperationOp::verify() {
271 bool isWithinRewrite = isa_and_nonnull<RewriteOp>(Val: (*this)->getParentOp());
272 if (isWithinRewrite && !getOpName())
273 return emitOpError(message: "must have an operation name when nested within "
274 "a `pdl.rewrite`");
275 ArrayAttr attributeNames = getAttributeValueNamesAttr();
276 auto attributeValues = getAttributeValues();
277 if (attributeNames.size() != attributeValues.size()) {
278 return emitOpError()
279 << "expected the same number of attribute values and attribute "
280 "names, got "
281 << attributeNames.size() << " names and " << attributeValues.size()
282 << " values";
283 }
284
285 // If the operation is within a rewrite body and doesn't have type inference,
286 // ensure that the result types can be resolved.
287 if (isWithinRewrite && !mightHaveTypeInference()) {
288 if (failed(Result: verifyResultTypesAreInferrable(op: *this, resultTypes: getTypeValues())))
289 return failure();
290 }
291
292 return verifyHasBindingUse(op: *this);
293}
294
295bool OperationOp::hasTypeInference() {
296 if (std::optional<StringRef> rawOpName = getOpName()) {
297 OperationName opName(*rawOpName, getContext());
298 return opName.hasInterface<InferTypeOpInterface>();
299 }
300 return false;
301}
302
303bool OperationOp::mightHaveTypeInference() {
304 if (std::optional<StringRef> rawOpName = getOpName()) {
305 OperationName opName(*rawOpName, getContext());
306 return opName.mightHaveInterface<InferTypeOpInterface>();
307 }
308 return false;
309}
310
//===----------------------------------------------------------------------===//
// pdl::PatternOp
//===----------------------------------------------------------------------===//
314
315LogicalResult PatternOp::verifyRegions() {
316 Region &body = getBodyRegion();
317 Operation *term = body.front().getTerminator();
318 auto rewriteOp = dyn_cast<RewriteOp>(Val: term);
319 if (!rewriteOp) {
320 return emitOpError(message: "expected body to terminate with `pdl.rewrite`")
321 .attachNote(noteLoc: term->getLoc())
322 .append(arg: "see terminator defined here");
323 }
324
325 // Check that all values defined in the top-level pattern belong to the PDL
326 // dialect.
327 WalkResult result = body.walk(callback: [&](Operation *op) -> WalkResult {
328 if (!isa_and_nonnull<PDLDialect>(Val: op->getDialect())) {
329 emitOpError(message: "expected only `pdl` operations within the pattern body")
330 .attachNote(noteLoc: op->getLoc())
331 .append(arg: "see non-`pdl` operation defined here");
332 return WalkResult::interrupt();
333 }
334 return WalkResult::advance();
335 });
336 if (result.wasInterrupted())
337 return failure();
338
339 // Check that there is at least one operation.
340 if (body.front().getOps<OperationOp>().empty())
341 return emitOpError(message: "the pattern must contain at least one `pdl.operation`");
342
343 // Determine if the operations within the pdl.pattern form a connected
344 // component. This is determined by starting the search from the first
345 // operand/result/operation and visiting their users / parents / operands.
346 // We limit our attention to operations that have a user in pdl.rewrite,
347 // those that do not will be detected via other means (expected bindable
348 // user).
349 bool first = true;
350 DenseSet<Operation *> visited;
351 for (Operation &op : body.front()) {
352 // The following are the operations forming the connected component.
353 if (!isa<OperandOp, OperandsOp, ResultOp, ResultsOp, OperationOp>(Val: op))
354 continue;
355
356 // Determine if the operation has a user in `pdl.rewrite`.
357 bool hasUserInRewrite = false;
358 for (Operation *user : op.getUsers()) {
359 Region *region = user->getParentRegion();
360 if (isa<RewriteOp>(Val: user) ||
361 (region && isa<RewriteOp>(Val: region->getParentOp()))) {
362 hasUserInRewrite = true;
363 break;
364 }
365 }
366
367 // If the operation does not have a user in `pdl.rewrite`, ignore it.
368 if (!hasUserInRewrite)
369 continue;
370
371 if (first) {
372 // For the first operation, invoke visit.
373 visit(op: &op, visited);
374 first = false;
375 } else if (!visited.count(V: &op)) {
376 // For the subsequent operations, check if already visited.
377 return emitOpError(message: "the operations must form a connected component")
378 .attachNote(noteLoc: op.getLoc())
379 .append(arg: "see a disconnected value / operation here");
380 }
381 }
382
383 return success();
384}
385
386void PatternOp::build(OpBuilder &builder, OperationState &state,
387 std::optional<uint16_t> benefit,
388 std::optional<StringRef> name) {
389 build(odsBuilder&: builder, odsState&: state, benefit: builder.getI16IntegerAttr(value: benefit.value_or(u: 0)),
390 sym_name: name ? builder.getStringAttr(bytes: *name) : StringAttr());
391 state.regions[0]->emplaceBlock();
392}
393
394/// Returns the rewrite operation of this pattern.
395RewriteOp PatternOp::getRewriter() {
396 return cast<RewriteOp>(Val: getBodyRegion().front().getTerminator());
397}
398
399/// The default dialect is `pdl`.
400StringRef PatternOp::getDefaultDialect() {
401 return PDLDialect::getDialectNamespace();
402}
403
//===----------------------------------------------------------------------===//
// pdl::RangeOp
//===----------------------------------------------------------------------===//
407
408static ParseResult parseRangeType(OpAsmParser &p, TypeRange argumentTypes,
409 Type &resultType) {
410 // If arguments were provided, infer the result type from the argument list.
411 if (!argumentTypes.empty()) {
412 resultType = RangeType::get(elementType: getRangeElementTypeOrSelf(type: argumentTypes[0]));
413 return success();
414 }
415 // Otherwise, parse the type as a trailing type.
416 return p.parseColonType(result&: resultType);
417}
418
419static void printRangeType(OpAsmPrinter &p, RangeOp op, TypeRange argumentTypes,
420 Type resultType) {
421 if (argumentTypes.empty())
422 p << ": " << resultType;
423}
424
425LogicalResult RangeOp::verify() {
426 Type elementType = getType().getElementType();
427 for (Type operandType : getOperandTypes()) {
428 Type operandElementType = getRangeElementTypeOrSelf(type: operandType);
429 if (operandElementType != elementType) {
430 return emitOpError(message: "expected operand to have element type ")
431 << elementType << ", but got " << operandElementType;
432 }
433 }
434 return success();
435}
436
//===----------------------------------------------------------------------===//
// pdl::ReplaceOp
//===----------------------------------------------------------------------===//
440
441LogicalResult ReplaceOp::verify() {
442 if (getReplOperation() && !getReplValues().empty())
443 return emitOpError() << "expected no replacement values to be provided"
444 " when the replacement operation is present";
445 return success();
446}
447
//===----------------------------------------------------------------------===//
// pdl::ResultsOp
//===----------------------------------------------------------------------===//
451
452static ParseResult parseResultsValueType(OpAsmParser &p, IntegerAttr index,
453 Type &resultType) {
454 if (!index) {
455 resultType = RangeType::get(elementType: p.getBuilder().getType<ValueType>());
456 return success();
457 }
458 if (p.parseArrow() || p.parseType(result&: resultType))
459 return failure();
460 return success();
461}
462
463static void printResultsValueType(OpAsmPrinter &p, ResultsOp op,
464 IntegerAttr index, Type resultType) {
465 if (index)
466 p << " -> " << resultType;
467}
468
469LogicalResult ResultsOp::verify() {
470 if (!getIndex() && llvm::isa<pdl::ValueType>(Val: getType())) {
471 return emitOpError() << "expected `pdl.range<value>` result type when "
472 "no index is specified, but got: "
473 << getType();
474 }
475 return success();
476}
477
//===----------------------------------------------------------------------===//
// pdl::RewriteOp
//===----------------------------------------------------------------------===//
481
482LogicalResult RewriteOp::verifyRegions() {
483 Region &rewriteRegion = getBodyRegion();
484
485 // Handle the case where the rewrite is external.
486 if (getName()) {
487 if (!rewriteRegion.empty()) {
488 return emitOpError()
489 << "expected rewrite region to be empty when rewrite is external";
490 }
491 return success();
492 }
493
494 // Otherwise, check that the rewrite region only contains a single block.
495 if (rewriteRegion.empty()) {
496 return emitOpError() << "expected rewrite region to be non-empty if "
497 "external name is not specified";
498 }
499
500 // Check that no additional arguments were provided.
501 if (!getExternalArgs().empty()) {
502 return emitOpError() << "expected no external arguments when the "
503 "rewrite is specified inline";
504 }
505
506 return success();
507}
508
509/// The default dialect is `pdl`.
510StringRef RewriteOp::getDefaultDialect() {
511 return PDLDialect::getDialectNamespace();
512}
513
//===----------------------------------------------------------------------===//
// pdl::TypeOp
//===----------------------------------------------------------------------===//
517
518LogicalResult TypeOp::verify() {
519 if (!getConstantTypeAttr())
520 return verifyHasBindingUse(op: *this);
521 return success();
522}
523
//===----------------------------------------------------------------------===//
// pdl::TypesOp
//===----------------------------------------------------------------------===//
527
528LogicalResult TypesOp::verify() {
529 if (!getConstantTypesAttr())
530 return verifyHasBindingUse(op: *this);
531 return success();
532}
533
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
537
538#define GET_OP_CLASSES
539#include "mlir/Dialect/PDL/IR/PDLOps.cpp.inc"
540

// Source: mlir/lib/Dialect/PDL/IR/PDL.cpp