| 1 | //===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" |
| 10 | #include "mlir/Dialect/PDL/IR/PDLOps.h" |
| 11 | #include "mlir/IR/Builders.h" |
| 12 | #include "mlir/IR/OpImplementation.h" |
| 13 | #include "mlir/Rewrite/FrozenRewritePatternSet.h" |
| 14 | #include "mlir/Rewrite/PatternApplicator.h" |
| 15 | #include "llvm/ADT/ScopeExit.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | |
// Defines the explicit TypeID for transform::PDLMatchHooks in this translation
// unit. The hooks are fetched from the Transform dialect in this file through
// `getExtraData<transform::PDLMatchHooks>()`.
MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks)
| 20 | |
| 21 | #define GET_OP_CLASSES |
| 22 | #include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" |
| 23 | |
| 24 | //===----------------------------------------------------------------------===// |
| 25 | // PatternApplicatorExtension |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | |
| 28 | namespace { |
| 29 | /// A TransformState extension that keeps track of compiled PDL pattern sets. |
| 30 | /// This is intended to be used along the WithPDLPatterns op. The extension |
| 31 | /// can be constructed given an operation that has a SymbolTable trait and |
| 32 | /// contains pdl::PatternOp instances. The patterns are compiled lazily and one |
| 33 | /// by one when requested; this behavior is subject to change. |
| 34 | class PatternApplicatorExtension : public transform::TransformState::Extension { |
| 35 | public: |
| 36 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) |
| 37 | |
| 38 | /// Creates the extension for patterns contained in `patternContainer`. |
| 39 | explicit PatternApplicatorExtension(transform::TransformState &state, |
| 40 | Operation *patternContainer) |
| 41 | : Extension(state), patterns(patternContainer) {} |
| 42 | |
| 43 | /// Appends to `results` the operations contained in `root` that matched the |
| 44 | /// PDL pattern with the given name. Note that `root` may or may not be the |
| 45 | /// operation that contains PDL patterns. Reports an error if the pattern |
| 46 | /// cannot be found. Note that when no operations are matched, this still |
| 47 | /// succeeds as long as the pattern exists. |
| 48 | LogicalResult findAllMatches(StringRef patternName, Operation *root, |
| 49 | SmallVectorImpl<Operation *> &results); |
| 50 | |
| 51 | private: |
| 52 | /// Map from the pattern name to a singleton set of rewrite patterns that only |
| 53 | /// contains the pattern with this name. Populated when the pattern is first |
| 54 | /// requested. |
| 55 | // TODO: reconsider the efficiency of this storage when more usage data is |
| 56 | // available. Storing individual patterns in a set and triggering compilation |
| 57 | // for each of them has overhead. So does compiling a large set of patterns |
| 58 | // only to apply a handful of them. |
| 59 | llvm::StringMap<FrozenRewritePatternSet> compiledPatterns; |
| 60 | |
| 61 | /// A symbol table operation containing the relevant PDL patterns. |
| 62 | SymbolTable patterns; |
| 63 | }; |
| 64 | |
| 65 | LogicalResult PatternApplicatorExtension::findAllMatches( |
| 66 | StringRef patternName, Operation *root, |
| 67 | SmallVectorImpl<Operation *> &results) { |
| 68 | auto it = compiledPatterns.find(patternName); |
| 69 | if (it == compiledPatterns.end()) { |
| 70 | auto patternOp = patterns.lookup<pdl::PatternOp>(patternName); |
| 71 | if (!patternOp) |
| 72 | return failure(); |
| 73 | |
| 74 | // Copy the pattern operation into a new module that is compiled and |
| 75 | // consumed by the PDL interpreter. |
| 76 | OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc()); |
| 77 | auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); |
| 78 | builder.clone(*patternOp); |
| 79 | PDLPatternModule patternModule(std::move(pdlModuleOp)); |
| 80 | |
| 81 | // Merge in the hooks owned by the dialect. Make a copy as they may be |
| 82 | // also used by the following operations. |
| 83 | auto *dialect = |
| 84 | root->getContext()->getLoadedDialect<transform::TransformDialect>(); |
| 85 | for (const auto &[name, constraintFn] : |
| 86 | dialect->getExtraData<transform::PDLMatchHooks>() |
| 87 | .getPDLConstraintHooks()) { |
| 88 | patternModule.registerConstraintFunction(name, constraintFn); |
| 89 | } |
| 90 | |
| 91 | // Register a noop rewriter because PDL requires patterns to end with some |
| 92 | // rewrite call. |
| 93 | patternModule.registerRewriteFunction( |
| 94 | "transform.dialect" , [](PatternRewriter &, Operation *) {}); |
| 95 | |
| 96 | it = compiledPatterns |
| 97 | .try_emplace(patternOp.getName(), std::move(patternModule)) |
| 98 | .first; |
| 99 | } |
| 100 | |
| 101 | PatternApplicator applicator(it->second); |
| 102 | // We want to discourage direct use of PatternRewriter in APIs but In this |
| 103 | // very specific case, an IRRewriter is not enough. |
| 104 | struct TrivialPatternRewriter : public PatternRewriter { |
| 105 | public: |
| 106 | explicit TrivialPatternRewriter(MLIRContext *context) |
| 107 | : PatternRewriter(context) {} |
| 108 | }; |
| 109 | TrivialPatternRewriter rewriter(root->getContext()); |
| 110 | applicator.applyDefaultCostModel(); |
| 111 | root->walk([&](Operation *op) { |
| 112 | if (succeeded(Result: applicator.matchAndRewrite(op, rewriter))) |
| 113 | results.push_back(Elt: op); |
| 114 | }); |
| 115 | |
| 116 | return success(); |
| 117 | } |
| 118 | } // namespace |
| 119 | |
| 120 | //===----------------------------------------------------------------------===// |
| 121 | // PDLMatchHooks |
| 122 | //===----------------------------------------------------------------------===// |
| 123 | |
| 124 | void transform::PDLMatchHooks::mergeInPDLMatchHooks( |
| 125 | llvm::StringMap<PDLConstraintFunction> &&constraintFns) { |
| 126 | // Steal the constraint functions from the given map. |
| 127 | for (auto &it : constraintFns) |
| 128 | pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); |
| 129 | } |
| 130 | |
| 131 | const llvm::StringMap<PDLConstraintFunction> & |
| 132 | transform::PDLMatchHooks::getPDLConstraintHooks() const { |
| 133 | return pdlMatchHooks.getConstraintFunctions(); |
| 134 | } |
| 135 | |
| 136 | //===----------------------------------------------------------------------===// |
| 137 | // PDLMatchOp |
| 138 | //===----------------------------------------------------------------------===// |
| 139 | |
| 140 | DiagnosedSilenceableFailure |
| 141 | transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, |
| 142 | transform::TransformResults &results, |
| 143 | transform::TransformState &state) { |
| 144 | auto *extension = state.getExtension<PatternApplicatorExtension>(); |
| 145 | assert(extension && |
| 146 | "expected PatternApplicatorExtension to be attached by the parent op" ); |
| 147 | SmallVector<Operation *> targets; |
| 148 | for (Operation *root : state.getPayloadOps(getRoot())) { |
| 149 | if (failed(extension->findAllMatches( |
| 150 | getPatternName().getLeafReference().getValue(), root, targets))) { |
| 151 | emitDefiniteFailure() |
| 152 | << "could not find pattern '" << getPatternName() << "'" ; |
| 153 | } |
| 154 | } |
| 155 | results.set(llvm::cast<OpResult>(getResult()), targets); |
| 156 | return DiagnosedSilenceableFailure::success(); |
| 157 | } |
| 158 | |
| 159 | void transform::PDLMatchOp::getEffects( |
| 160 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 161 | onlyReadsHandle(getRootMutable(), effects); |
| 162 | producesHandle(getOperation()->getOpResults(), effects); |
| 163 | onlyReadsPayload(effects); |
| 164 | } |
| 165 | |
| 166 | //===----------------------------------------------------------------------===// |
| 167 | // WithPDLPatternsOp |
| 168 | //===----------------------------------------------------------------------===// |
| 169 | |
| 170 | DiagnosedSilenceableFailure |
| 171 | transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, |
| 172 | transform::TransformResults &results, |
| 173 | transform::TransformState &state) { |
| 174 | TransformOpInterface transformOp = nullptr; |
| 175 | for (Operation &nested : getBody().front()) { |
| 176 | if (!isa<pdl::PatternOp>(nested)) { |
| 177 | transformOp = cast<TransformOpInterface>(nested); |
| 178 | break; |
| 179 | } |
| 180 | } |
| 181 | |
| 182 | state.addExtension<PatternApplicatorExtension>(getOperation()); |
| 183 | auto guard = llvm::make_scope_exit( |
| 184 | [&]() { state.removeExtension<PatternApplicatorExtension>(); }); |
| 185 | |
| 186 | auto scope = state.make_region_scope(getBody()); |
| 187 | if (failed(mapBlockArguments(state))) |
| 188 | return DiagnosedSilenceableFailure::definiteFailure(); |
| 189 | return state.applyTransform(transformOp); |
| 190 | } |
| 191 | |
| 192 | void transform::WithPDLPatternsOp::getEffects( |
| 193 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 194 | getPotentialTopLevelEffects(effects); |
| 195 | } |
| 196 | |
| 197 | LogicalResult transform::WithPDLPatternsOp::verify() { |
| 198 | Block *body = getBodyBlock(); |
| 199 | Operation *topLevelOp = nullptr; |
| 200 | for (Operation &op : body->getOperations()) { |
| 201 | if (isa<pdl::PatternOp>(op)) |
| 202 | continue; |
| 203 | |
| 204 | if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { |
| 205 | if (topLevelOp) { |
| 206 | InFlightDiagnostic diag = |
| 207 | emitOpError() << "expects only one non-pattern op in its body" ; |
| 208 | diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op" ; |
| 209 | diag.attachNote(op.getLoc()) << "second non-pattern op" ; |
| 210 | return diag; |
| 211 | } |
| 212 | topLevelOp = &op; |
| 213 | continue; |
| 214 | } |
| 215 | |
| 216 | InFlightDiagnostic diag = |
| 217 | emitOpError() |
| 218 | << "expects only pattern and top-level transform ops in its body" ; |
| 219 | diag.attachNote(op.getLoc()) << "offending op" ; |
| 220 | return diag; |
| 221 | } |
| 222 | |
| 223 | if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) { |
| 224 | InFlightDiagnostic diag = emitOpError() << "cannot be nested" ; |
| 225 | diag.attachNote(parent.getLoc()) << "parent operation" ; |
| 226 | return diag; |
| 227 | } |
| 228 | |
| 229 | if (!topLevelOp) { |
| 230 | InFlightDiagnostic diag = emitOpError() |
| 231 | << "expects at least one non-pattern op" ; |
| 232 | return diag; |
| 233 | } |
| 234 | |
| 235 | return success(); |
| 236 | } |
| 237 | |