1 | //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
10 | #include "mlir/Pass/Pass.h" |
11 | #include "mlir/Pass/PassManager.h" |
12 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | /// Custom constraint invoked from PDL. |
17 | static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, |
18 | Operation *rootOp) { |
19 | return success(isSuccess: rootOp->getName().getStringRef() == "test.op" ); |
20 | } |
21 | static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, |
22 | Operation *root, |
23 | Operation *rootCopy) { |
24 | return customSingleEntityConstraint(rewriter, rootOp: rootCopy); |
25 | } |
26 | static LogicalResult customMultiEntityVariadicConstraint( |
27 | PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { |
28 | if (operandValues.size() != 2 || typeValues.size() != 2) |
29 | return failure(); |
30 | return success(); |
31 | } |
32 | |
33 | // Custom constraint that returns a value if the op is named test.success_op |
34 | static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, |
35 | PDLResultList &results, |
36 | ArrayRef<PDLValue> args) { |
37 | auto *op = args[0].cast<Operation *>(); |
38 | if (op->getName().getStringRef() == "test.success_op" ) { |
39 | StringAttr customAttr = rewriter.getStringAttr("test.success" ); |
40 | results.push_back(customAttr); |
41 | return success(); |
42 | } |
43 | return failure(); |
44 | } |
45 | |
46 | // Custom constraint that returns a type if the op is named test.success_op |
47 | static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, |
48 | PDLResultList &results, |
49 | ArrayRef<PDLValue> args) { |
50 | auto *op = args[0].cast<Operation *>(); |
51 | if (op->getName().getStringRef() == "test.success_op" ) { |
52 | results.push_back(value: rewriter.getF32Type()); |
53 | return success(); |
54 | } |
55 | return failure(); |
56 | } |
57 | |
58 | // Custom constraint that returns a type range of variable length if the op is |
59 | // named test.success_op |
60 | static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, |
61 | PDLResultList &results, |
62 | ArrayRef<PDLValue> args) { |
63 | auto *op = args[0].cast<Operation *>(); |
64 | int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt(); |
65 | |
66 | if (op->getName().getStringRef() == "test.success_op" ) { |
67 | SmallVector<Type> types; |
68 | for (int i = 0; i < numTypes; i++) { |
69 | types.push_back(Elt: rewriter.getF32Type()); |
70 | } |
71 | results.push_back(value: TypeRange(types)); |
72 | return success(); |
73 | } |
74 | return failure(); |
75 | } |
76 | |
77 | // Custom creator invoked from PDL. |
78 | static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { |
79 | return rewriter.create(state: OperationState(op->getLoc(), "test.success" )); |
80 | } |
81 | static auto customVariadicResultCreate(PatternRewriter &rewriter, |
82 | Operation *root) { |
83 | return std::make_pair(x: root->getOperands(), y: root->getOperands().getTypes()); |
84 | } |
85 | static Type customCreateType(PatternRewriter &rewriter) { |
86 | return rewriter.getF32Type(); |
87 | } |
88 | static std::string customCreateStrAttr(PatternRewriter &rewriter) { |
89 | return "test.str" ; |
90 | } |
91 | |
92 | /// Custom rewriter invoked from PDL. |
93 | static void customRewriter(PatternRewriter &rewriter, Operation *root, |
94 | Value input) { |
95 | rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success" ), |
96 | input); |
97 | rewriter.eraseOp(op: root); |
98 | } |
99 | |
100 | namespace { |
101 | struct TestPDLByteCodePass |
102 | : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { |
103 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) |
104 | |
105 | StringRef getArgument() const final { return "test-pdl-bytecode-pass" ; } |
106 | StringRef getDescription() const final { |
107 | return "Test PDL ByteCode functionality" ; |
108 | } |
109 | void getDependentDialects(DialectRegistry ®istry) const override { |
110 | // Mark the pdl_interp dialect as a dependent. This is needed, because we |
111 | // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. |
112 | registry.insert<pdl_interp::PDLInterpDialect>(); |
113 | } |
114 | void runOnOperation() final { |
115 | ModuleOp module = getOperation(); |
116 | |
117 | // The test cases are encompassed via two modules, one containing the |
118 | // patterns and one containing the operations to rewrite. |
119 | ModuleOp patternModule = module.lookupSymbol<ModuleOp>( |
120 | StringAttr::get(module->getContext(), "patterns" )); |
121 | ModuleOp irModule = module.lookupSymbol<ModuleOp>( |
122 | StringAttr::get(module->getContext(), "ir" )); |
123 | if (!patternModule || !irModule) |
124 | return; |
125 | |
126 | RewritePatternSet patternList(module->getContext()); |
127 | |
128 | // Register ahead of time to test when functions are registered without a |
129 | // pattern. |
130 | patternList.getPDLPatterns().registerConstraintFunction( |
131 | name: "multi_entity_constraint" , constraintFn&: customMultiEntityConstraint); |
132 | patternList.getPDLPatterns().registerConstraintFunction( |
133 | name: "single_entity_constraint" , constraintFn&: customSingleEntityConstraint); |
134 | |
135 | // Process the pattern module. |
136 | patternModule.getOperation()->remove(); |
137 | PDLPatternModule pdlPattern(patternModule); |
138 | |
139 | // Note: This constraint was already registered, but we re-register here to |
140 | // ensure that duplication registration is allowed (the duplicate mapping |
141 | // will be ignored). This tests that we support separating the registration |
142 | // of library functions from the construction of patterns, and also that we |
143 | // allow multiple patterns to depend on the same library functions (without |
144 | // asserting/crashing). |
145 | pdlPattern.registerConstraintFunction(name: "multi_entity_constraint" , |
146 | constraintFn&: customMultiEntityConstraint); |
147 | pdlPattern.registerConstraintFunction(name: "multi_entity_var_constraint" , |
148 | constraintFn&: customMultiEntityVariadicConstraint); |
149 | pdlPattern.registerConstraintFunction(name: "op_constr_return_attr" , |
150 | constraintFn&: customValueResultConstraint); |
151 | pdlPattern.registerConstraintFunction(name: "op_constr_return_type" , |
152 | constraintFn&: customTypeResultConstraint); |
153 | pdlPattern.registerConstraintFunction(name: "op_constr_return_type_range" , |
154 | constraintFn&: customTypeRangeResultConstraint); |
155 | pdlPattern.registerRewriteFunction(name: "creator" , rewriteFn&: customCreate); |
156 | pdlPattern.registerRewriteFunction(name: "var_creator" , |
157 | rewriteFn&: customVariadicResultCreate); |
158 | pdlPattern.registerRewriteFunction(name: "type_creator" , rewriteFn&: customCreateType); |
159 | pdlPattern.registerRewriteFunction(name: "str_creator" , rewriteFn&: customCreateStrAttr); |
160 | pdlPattern.registerRewriteFunction(name: "rewriter" , rewriteFn&: customRewriter); |
161 | patternList.add(pattern: std::move(pdlPattern)); |
162 | |
163 | // Invoke the pattern driver with the provided patterns. |
164 | (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(), |
165 | std::move(patternList)); |
166 | } |
167 | }; |
168 | } // namespace |
169 | |
170 | namespace mlir { |
171 | namespace test { |
172 | void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } |
173 | } // namespace test |
174 | } // namespace mlir |
175 | |