1 | //===- IRDLExtensionOps.cpp - IRDL extension for the Transform dialect ----===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.h" |
10 | #include "mlir/Dialect/IRDL/IR/IRDL.h" |
11 | #include "mlir/Dialect/IRDL/IRDLVerifiers.h" |
12 | #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" |
13 | #include "mlir/IR/Diagnostics.h" |
14 | #include "mlir/IR/ExtensibleDialect.h" |
15 | #include "mlir/Interfaces/SideEffectInterfaces.h" |
16 | #include "llvm/ADT/STLExtras.h" |
17 | |
18 | using namespace mlir; |
19 | |
20 | #define GET_OP_CLASSES |
21 | #include "mlir/Dialect/Transform/IRDLExtension/IRDLExtensionOps.cpp.inc" |
22 | |
23 | namespace mlir::transform { |
24 | |
25 | DiagnosedSilenceableFailure |
26 | IRDLCollectMatchingOp::apply(TransformRewriter &rewriter, |
27 | TransformResults &results, TransformState &state) { |
28 | auto dialect = cast<irdl::DialectOp>(getBody().front().front()); |
29 | Block &body = dialect.getBody().front(); |
30 | irdl::OperationOp operation = *body.getOps<irdl::OperationOp>().begin(); |
31 | auto verifier = irdl::createVerifier( |
32 | operation, |
33 | DenseMap<irdl::TypeOp, std::unique_ptr<DynamicTypeDefinition>>(), |
34 | DenseMap<irdl::AttributeOp, std::unique_ptr<DynamicAttrDefinition>>()); |
35 | |
36 | auto handlerID = getContext()->getDiagEngine().registerHandler( |
37 | [](Diagnostic &) { return success(); }); |
38 | SmallVector<Operation *> matched; |
39 | for (Operation *payload : state.getPayloadOps(getRoot())) { |
40 | payload->walk([&](Operation *target) { |
41 | if (succeeded(verifier(target))) { |
42 | matched.push_back(target); |
43 | } |
44 | }); |
45 | } |
46 | getContext()->getDiagEngine().eraseHandler(handlerID); |
47 | results.set(cast<OpResult>(getMatched()), matched); |
48 | return DiagnosedSilenceableFailure::success(); |
49 | } |
50 | |
51 | void IRDLCollectMatchingOp::getEffects( |
52 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
53 | onlyReadsHandle(getRootMutable(), effects); |
54 | producesHandle(getOperation()->getOpResults(), effects); |
55 | onlyReadsPayload(effects); |
56 | } |
57 | |
58 | LogicalResult IRDLCollectMatchingOp::verify() { |
59 | Block &bodyBlock = getBody().front(); |
60 | if (!llvm::hasSingleElement(bodyBlock)) |
61 | return emitOpError() << "expects a single operation in the body" ; |
62 | |
63 | auto dialect = dyn_cast<irdl::DialectOp>(bodyBlock.front()); |
64 | if (!dialect) { |
65 | return emitOpError() << "expects the body operation to be " |
66 | << irdl::DialectOp::getOperationName(); |
67 | } |
68 | |
69 | // TODO: relax this by taking a symbol name of the operation to match, note |
70 | // that symbol name is also the name of the operation and we may want to |
71 | // divert from that to have constraints on-the-fly using IRDL. |
72 | auto irdlOperations = dialect.getOps<irdl::OperationOp>(); |
73 | if (!llvm::hasSingleElement(irdlOperations)) |
74 | return emitOpError() << "expects IRDL to contain exactly one operation" ; |
75 | |
76 | if (!dialect.getOps<irdl::TypeOp>().empty() || |
77 | !dialect.getOps<irdl::AttributeOp>().empty()) { |
78 | return emitOpError() << "IRDL types and attributes are not yet supported" ; |
79 | } |
80 | |
81 | return success(); |
82 | } |
83 | |
84 | } // namespace mlir::transform |
85 | |