| 1 | //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" |
| 10 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
| 11 | #include "mlir/IR/Builders.h" |
| 12 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 13 | #include "mlir/Rewrite/PatternApplicator.h" |
| 14 | #include "llvm/ADT/ScopeExit.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | |
// Out-of-line TypeID definition for transform::PDLMatchHooks. PDLMatchHooks is
// retrieved below through TransformDialect::getExtraData<PDLMatchHooks>(), which
// presumably identifies the extra-data entry by this TypeID (NOTE(review):
// confirm against TransformDialect.h).
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
| 19 | |
| 20 | #define GET_OP_CLASSES |
| 21 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" |
| 22 | |
| 23 | //===----------------------------------------------------------------------===// |
| 24 | // PatternApplicatorExtension |
| 25 | //===----------------------------------------------------------------------===// |
| 26 | |
| 27 | namespace { |
| 28 | /// A TransformState extension that keeps track of compiled PDL pattern sets. |
| 29 | /// This is intended to be used along the WithPDLPatterns op. The extension |
| 30 | /// can be constructed given an operation that has a SymbolTable trait and |
| 31 | /// contains pdl::PatternOp instances. The patterns are compiled lazily and one |
| 32 | /// by one when requested; this behavior is subject to change. |
| 33 | class PatternApplicatorExtension : public transform::TransformState::Extension { |
| 34 | public: |
| 35 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) |
| 36 | |
| 37 | /// Creates the extension for patterns contained in `patternContainer`. |
| 38 | explicit PatternApplicatorExtension(transform::TransformState &state, |
| 39 | Operation *patternContainer) |
| 40 | : Extension(state), patterns(patternContainer) {} |
| 41 | |
| 42 | /// Appends to `results` the operations contained in `root` that matched the |
| 43 | /// PDL pattern with the given name. Note that `root` may or may not be the |
| 44 | /// operation that contains PDL patterns. Reports an error if the pattern |
| 45 | /// cannot be found. Note that when no operations are matched, this still |
| 46 | /// succeeds as long as the pattern exists. |
| 47 | LogicalResult findAllMatches(StringRef patternName, Operation *root, |
| 48 | SmallVectorImpl<Operation *> &results); |
| 49 | |
| 50 | private: |
| 51 | /// Map from the pattern name to a singleton set of rewrite patterns that only |
| 52 | /// contains the pattern with this name. Populated when the pattern is first |
| 53 | /// requested. |
| 54 | // TODO: reconsider the efficiency of this storage when more usage data is |
| 55 | // available. Storing individual patterns in a set and triggering compilation |
| 56 | // for each of them has overhead. So does compiling a large set of patterns |
| 57 | // only to apply a handful of them. |
| 58 | llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; |
| 59 | |
| 60 | /// A symbol table operation containing the relevant PDL patterns. |
| 61 | SymbolTable patterns; |
| 62 | }; |
| 63 | |
| 64 | LogicalResult PatternApplicatorExtension::findAllMatches( |
| 65 | StringRef patternName, Operation *root, |
| 66 | SmallVectorImpl<Operation *> &results) { |
| 67 | auto it = compiledPatterns.find(Key: patternName); |
| 68 | if (it == compiledPatterns.end()) { |
| 69 | auto patternOp = patterns.lookup<pdl::PatternOp>(name: patternName); |
| 70 | if (!patternOp) |
| 71 | return failure(); |
| 72 | |
| 73 | // Copy the pattern operation into a new module that is compiled and |
| 74 | // consumed by the PDL interpreter. |
| 75 | OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(loc: patternOp.getLoc()); |
| 76 | auto builder = OpBuilder::atBlockEnd(block: pdlModuleOp->getBody()); |
| 77 | builder.clone(op&: *patternOp); |
| 78 | PDLPatternModule patternModule(std::move(pdlModuleOp)); |
| 79 | |
| 80 | // Merge in the hooks owned by the dialect. Make a copy as they may be |
| 81 | // also used by the following operations. |
| 82 | auto *dialect = |
| 83 | root->getContext()->getLoadedDialect<transform::TransformDialect>(); |
| 84 | for (const auto &[name, constraintFn] : |
| 85 | dialect->getExtraData<transform::PDLMatchHooks>() |
| 86 | .getPDLConstraintHooks()) { |
| 87 | patternModule.registerConstraintFunction(name, constraintFn); |
| 88 | } |
| 89 | |
| 90 | // Register a noop rewriter because PDL requires patterns to end with some |
| 91 | // rewrite call. |
| 92 | patternModule.registerRewriteFunction( |
| 93 | name: "transform.dialect" , rewriteFn: [](PatternRewriter &, Operation *) {}); |
| 94 | |
| 95 | it = compiledPatterns |
| 96 | .try_emplace(Key: patternOp.getName(), Args: std::move(patternModule)) |
| 97 | .first; |
| 98 | } |
| 99 | |
| 100 | PatternApplicator applicator(it->second); |
| 101 | // We want to discourage direct use of PatternRewriter in APIs but In this |
| 102 | // very specific case, an IRRewriter is not enough. |
| 103 | struct TrivialPatternRewriter : public PatternRewriter { |
| 104 | public: |
| 105 | explicit TrivialPatternRewriter(MLIRContext *context) |
| 106 | : PatternRewriter(context) {} |
| 107 | }; |
| 108 | TrivialPatternRewriter rewriter(root->getContext()); |
| 109 | applicator.applyDefaultCostModel(); |
| 110 | root->walk(callback: [&](Operation *op) { |
| 111 | if (succeeded(Result: applicator.matchAndRewrite(op, rewriter))) |
| 112 | results.push_back(Elt: op); |
| 113 | }); |
| 114 | |
| 115 | return success(); |
| 116 | } |
| 117 | } // namespace |
| 118 | |
| 119 | //===----------------------------------------------------------------------===// |
| 120 | // PDLMatchHooks |
| 121 | //===----------------------------------------------------------------------===// |
| 122 | |
| 123 | void transform::PDLMatchHooks::mergeInPDLMatchHooks( |
| 124 | llvm::StringMap<PDLConstraintFunction> &&constraintFns) { |
| 125 | // Steal the constraint functions from the given map. |
| 126 | for (auto &it : constraintFns) |
| 127 | pdlMatchHooks.registerConstraintFunction(name: it.getKey(), constraintFn: std::move(it.second)); |
| 128 | } |
| 129 | |
| 130 | const llvm::StringMap<PDLConstraintFunction> & |
| 131 | transform::PDLMatchHooks::getPDLConstraintHooks() const { |
| 132 | return pdlMatchHooks.getConstraintFunctions(); |
| 133 | } |
| 134 | |
| 135 | //===----------------------------------------------------------------------===// |
| 136 | // PDLMatchOp |
| 137 | //===----------------------------------------------------------------------===// |
| 138 | |
| 139 | DiagnosedSilenceableFailure |
| 140 | transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, |
| 141 | transform::TransformResults &results, |
| 142 | transform::TransformState &state) { |
| 143 | auto *extension = state.getExtension<PatternApplicatorExtension>(); |
| 144 | assert(extension && |
| 145 | "expected PatternApplicatorExtension to be attached by the parent op" ); |
| 146 | SmallVector<Operation *> targets; |
| 147 | for (Operation *root : state.getPayloadOps(value: getRoot())) { |
| 148 | if (failed(Result: extension->findAllMatches( |
| 149 | patternName: getPatternName().getLeafReference().getValue(), root, results&: targets))) { |
| 150 | emitDefiniteFailure() |
| 151 | << "could not find pattern '" << getPatternName() << "'" ; |
| 152 | } |
| 153 | } |
| 154 | results.set(value: llvm::cast<OpResult>(Val: getResult()), ops&: targets); |
| 155 | return DiagnosedSilenceableFailure::success(); |
| 156 | } |
| 157 | |
| 158 | void transform::PDLMatchOp::getEffects( |
| 159 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 160 | onlyReadsHandle(handles: getRootMutable(), effects); |
| 161 | producesHandle(handles: getOperation()->getOpResults(), effects); |
| 162 | onlyReadsPayload(effects); |
| 163 | } |
| 164 | |
| 165 | //===----------------------------------------------------------------------===// |
| 166 | // WithPDLPatternsOp |
| 167 | //===----------------------------------------------------------------------===// |
| 168 | |
| 169 | DiagnosedSilenceableFailure |
| 170 | transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, |
| 171 | transform::TransformResults &results, |
| 172 | transform::TransformState &state) { |
| 173 | TransformOpInterface transformOp = nullptr; |
| 174 | for (Operation &nested : getBody().front()) { |
| 175 | if (!isa<pdl::PatternOp>(Val: nested)) { |
| 176 | transformOp = cast<TransformOpInterface>(Val&: nested); |
| 177 | break; |
| 178 | } |
| 179 | } |
| 180 | |
| 181 | state.addExtension<PatternApplicatorExtension>(args: getOperation()); |
| 182 | auto guard = llvm::make_scope_exit( |
| 183 | F: [&]() { state.removeExtension<PatternApplicatorExtension>(); }); |
| 184 | |
| 185 | auto scope = state.make_region_scope(region&: getBody()); |
| 186 | if (failed(Result: mapBlockArguments(state))) |
| 187 | return DiagnosedSilenceableFailure::definiteFailure(); |
| 188 | return state.applyTransform(transform: transformOp); |
| 189 | } |
| 190 | |
| 191 | void transform::WithPDLPatternsOp::getEffects( |
| 192 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 193 | getPotentialTopLevelEffects(effects); |
| 194 | } |
| 195 | |
| 196 | LogicalResult transform::WithPDLPatternsOp::verify() { |
| 197 | Block *body = getBodyBlock(); |
| 198 | Operation *topLevelOp = nullptr; |
| 199 | for (Operation &op : body->getOperations()) { |
| 200 | if (isa<pdl::PatternOp>(Val: op)) |
| 201 | continue; |
| 202 | |
| 203 | if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { |
| 204 | if (topLevelOp) { |
| 205 | InFlightDiagnostic diag = |
| 206 | emitOpError() << "expects only one non-pattern op in its body" ; |
| 207 | diag.attachNote(noteLoc: topLevelOp->getLoc()) << "first non-pattern op" ; |
| 208 | diag.attachNote(noteLoc: op.getLoc()) << "second non-pattern op" ; |
| 209 | return diag; |
| 210 | } |
| 211 | topLevelOp = &op; |
| 212 | continue; |
| 213 | } |
| 214 | |
| 215 | InFlightDiagnostic diag = |
| 216 | emitOpError() |
| 217 | << "expects only pattern and top-level transform ops in its body" ; |
| 218 | diag.attachNote(noteLoc: op.getLoc()) << "offending op" ; |
| 219 | return diag; |
| 220 | } |
| 221 | |
| 222 | if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { |
| 223 | InFlightDiagnostic diag = emitOpError() << "cannot be nested" ; |
| 224 | diag.attachNote(noteLoc: parent.getLoc()) << "parent operation" ; |
| 225 | return diag; |
| 226 | } |
| 227 | |
| 228 | if (!topLevelOp) { |
| 229 | InFlightDiagnostic diag = emitOpError() |
| 230 | << "expects at least one non-pattern op" ; |
| 231 | return diag; |
| 232 | } |
| 233 | |
| 234 | return success(); |
| 235 | } |
| 236 | |