1//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10#include "mlir/Pass/Pass.h"
11#include "mlir/Pass/PassManager.h"
12#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13
14using namespace mlir;
15
16/// Custom constraint invoked from PDL.
17static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18 Operation *rootOp) {
19 return success(isSuccess: rootOp->getName().getStringRef() == "test.op");
20}
21static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22 Operation *root,
23 Operation *rootCopy) {
24 return customSingleEntityConstraint(rewriter, rootOp: rootCopy);
25}
26static LogicalResult customMultiEntityVariadicConstraint(
27 PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
28 if (operandValues.size() != 2 || typeValues.size() != 2)
29 return failure();
30 return success();
31}
32
33// Custom constraint that returns a value if the op is named test.success_op
34static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
35 PDLResultList &results,
36 ArrayRef<PDLValue> args) {
37 auto *op = args[0].cast<Operation *>();
38 if (op->getName().getStringRef() == "test.success_op") {
39 StringAttr customAttr = rewriter.getStringAttr("test.success");
40 results.push_back(customAttr);
41 return success();
42 }
43 return failure();
44}
45
46// Custom constraint that returns a type if the op is named test.success_op
47static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
48 PDLResultList &results,
49 ArrayRef<PDLValue> args) {
50 auto *op = args[0].cast<Operation *>();
51 if (op->getName().getStringRef() == "test.success_op") {
52 results.push_back(value: rewriter.getF32Type());
53 return success();
54 }
55 return failure();
56}
57
58// Custom constraint that returns a type range of variable length if the op is
59// named test.success_op
60static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
61 PDLResultList &results,
62 ArrayRef<PDLValue> args) {
63 auto *op = args[0].cast<Operation *>();
64 int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
65
66 if (op->getName().getStringRef() == "test.success_op") {
67 SmallVector<Type> types;
68 for (int i = 0; i < numTypes; i++) {
69 types.push_back(Elt: rewriter.getF32Type());
70 }
71 results.push_back(value: TypeRange(types));
72 return success();
73 }
74 return failure();
75}
76
77// Custom creator invoked from PDL.
78static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
79 return rewriter.create(state: OperationState(op->getLoc(), "test.success"));
80}
81static auto customVariadicResultCreate(PatternRewriter &rewriter,
82 Operation *root) {
83 return std::make_pair(x: root->getOperands(), y: root->getOperands().getTypes());
84}
85static Type customCreateType(PatternRewriter &rewriter) {
86 return rewriter.getF32Type();
87}
88static std::string customCreateStrAttr(PatternRewriter &rewriter) {
89 return "test.str";
90}
91
92/// Custom rewriter invoked from PDL.
93static void customRewriter(PatternRewriter &rewriter, Operation *root,
94 Value input) {
95 rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
96 input);
97 rewriter.eraseOp(op: root);
98}
99
100namespace {
101struct TestPDLByteCodePass
102 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
103 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
104
105 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
106 StringRef getDescription() const final {
107 return "Test PDL ByteCode functionality";
108 }
109 void getDependentDialects(DialectRegistry &registry) const override {
110 // Mark the pdl_interp dialect as a dependent. This is needed, because we
111 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
112 registry.insert<pdl_interp::PDLInterpDialect>();
113 }
114 void runOnOperation() final {
115 ModuleOp module = getOperation();
116
117 // The test cases are encompassed via two modules, one containing the
118 // patterns and one containing the operations to rewrite.
119 ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
120 StringAttr::get(module->getContext(), "patterns"));
121 ModuleOp irModule = module.lookupSymbol<ModuleOp>(
122 StringAttr::get(module->getContext(), "ir"));
123 if (!patternModule || !irModule)
124 return;
125
126 RewritePatternSet patternList(module->getContext());
127
128 // Register ahead of time to test when functions are registered without a
129 // pattern.
130 patternList.getPDLPatterns().registerConstraintFunction(
131 name: "multi_entity_constraint", constraintFn&: customMultiEntityConstraint);
132 patternList.getPDLPatterns().registerConstraintFunction(
133 name: "single_entity_constraint", constraintFn&: customSingleEntityConstraint);
134
135 // Process the pattern module.
136 patternModule.getOperation()->remove();
137 PDLPatternModule pdlPattern(patternModule);
138
139 // Note: This constraint was already registered, but we re-register here to
140 // ensure that duplication registration is allowed (the duplicate mapping
141 // will be ignored). This tests that we support separating the registration
142 // of library functions from the construction of patterns, and also that we
143 // allow multiple patterns to depend on the same library functions (without
144 // asserting/crashing).
145 pdlPattern.registerConstraintFunction(name: "multi_entity_constraint",
146 constraintFn&: customMultiEntityConstraint);
147 pdlPattern.registerConstraintFunction(name: "multi_entity_var_constraint",
148 constraintFn&: customMultiEntityVariadicConstraint);
149 pdlPattern.registerConstraintFunction(name: "op_constr_return_attr",
150 constraintFn&: customValueResultConstraint);
151 pdlPattern.registerConstraintFunction(name: "op_constr_return_type",
152 constraintFn&: customTypeResultConstraint);
153 pdlPattern.registerConstraintFunction(name: "op_constr_return_type_range",
154 constraintFn&: customTypeRangeResultConstraint);
155 pdlPattern.registerRewriteFunction(name: "creator", rewriteFn&: customCreate);
156 pdlPattern.registerRewriteFunction(name: "var_creator",
157 rewriteFn&: customVariadicResultCreate);
158 pdlPattern.registerRewriteFunction(name: "type_creator", rewriteFn&: customCreateType);
159 pdlPattern.registerRewriteFunction(name: "str_creator", rewriteFn&: customCreateStrAttr);
160 pdlPattern.registerRewriteFunction(name: "rewriter", rewriteFn&: customRewriter);
161 patternList.add(pattern: std::move(pdlPattern));
162
163 // Invoke the pattern driver with the provided patterns.
164 (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),
165 std::move(patternList));
166 }
167};
168} // namespace
169
170namespace mlir {
171namespace test {
172void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
173} // namespace test
174} // namespace mlir
175

source code of mlir/test/lib/Rewrite/TestPDLByteCode.cpp