1 | //===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
10 | #include "mlir/Pass/Pass.h" |
11 | #include "mlir/Pass/PassManager.h" |
12 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
13 | |
14 | using namespace mlir; |
15 | |
16 | /// Custom constraint invoked from PDL. |
17 | static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter, |
18 | Operation *rootOp) { |
19 | return success(IsSuccess: rootOp->getName().getStringRef() == "test.op"); |
20 | } |
21 | static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter, |
22 | Operation *root, |
23 | Operation *rootCopy) { |
24 | return customSingleEntityConstraint(rewriter, rootOp: rootCopy); |
25 | } |
26 | static LogicalResult customMultiEntityVariadicConstraint( |
27 | PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) { |
28 | if (operandValues.size() != 2 || typeValues.size() != 2) |
29 | return failure(); |
30 | return success(); |
31 | } |
32 | |
33 | // Custom constraint that returns a value if the op is named test.success_op |
34 | static LogicalResult customValueResultConstraint(PatternRewriter &rewriter, |
35 | PDLResultList &results, |
36 | ArrayRef<PDLValue> args) { |
37 | auto *op = args[0].cast<Operation *>(); |
38 | if (op->getName().getStringRef() == "test.success_op") { |
39 | StringAttr customAttr = rewriter.getStringAttr("test.success"); |
40 | results.push_back(customAttr); |
41 | return success(); |
42 | } |
43 | return failure(); |
44 | } |
45 | |
46 | // Custom constraint that returns a type if the op is named test.success_op |
47 | static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter, |
48 | PDLResultList &results, |
49 | ArrayRef<PDLValue> args) { |
50 | auto *op = args[0].cast<Operation *>(); |
51 | if (op->getName().getStringRef() == "test.success_op") { |
52 | results.push_back(rewriter.getF32Type()); |
53 | return success(); |
54 | } |
55 | return failure(); |
56 | } |
57 | |
58 | // Custom constraint that always returns failure |
59 | static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/, |
60 | PDLResultList & /*results*/, |
61 | ArrayRef<PDLValue> /*args*/) { |
62 | return failure(); |
63 | } |
64 | |
65 | // Custom constraint that returns a type range of variable length if the op is |
66 | // named test.success_op |
67 | static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter, |
68 | PDLResultList &results, |
69 | ArrayRef<PDLValue> args) { |
70 | auto *op = args[0].cast<Operation *>(); |
71 | int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt(); |
72 | |
73 | if (op->getName().getStringRef() == "test.success_op") { |
74 | SmallVector<Type> types; |
75 | for (int i = 0; i < numTypes; i++) { |
76 | types.push_back(rewriter.getF32Type()); |
77 | } |
78 | results.push_back(value: TypeRange(types)); |
79 | return success(); |
80 | } |
81 | return failure(); |
82 | } |
83 | |
84 | // Custom creator invoked from PDL. |
85 | static Operation *customCreate(PatternRewriter &rewriter, Operation *op) { |
86 | return rewriter.create(state: OperationState(op->getLoc(), "test.success")); |
87 | } |
88 | static auto customVariadicResultCreate(PatternRewriter &rewriter, |
89 | Operation *root) { |
90 | return std::make_pair(x: root->getOperands(), y: root->getOperands().getTypes()); |
91 | } |
92 | static Type customCreateType(PatternRewriter &rewriter) { |
93 | return rewriter.getF32Type(); |
94 | } |
95 | static std::string customCreateStrAttr(PatternRewriter &rewriter) { |
96 | return "test.str"; |
97 | } |
98 | |
99 | /// Custom rewriter invoked from PDL. |
100 | static void customRewriter(PatternRewriter &rewriter, Operation *root, |
101 | Value input) { |
102 | rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"), |
103 | input); |
104 | rewriter.eraseOp(op: root); |
105 | } |
106 | |
107 | namespace { |
108 | struct TestPDLByteCodePass |
109 | : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> { |
110 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass) |
111 | |
112 | StringRef getArgument() const final { return "test-pdl-bytecode-pass"; } |
113 | StringRef getDescription() const final { |
114 | return "Test PDL ByteCode functionality"; |
115 | } |
116 | void getDependentDialects(DialectRegistry ®istry) const override { |
117 | // Mark the pdl_interp dialect as a dependent. This is needed, because we |
118 | // create ops from that dialect as a part of the PDL-to-PDLInterp lowering. |
119 | registry.insert<pdl_interp::PDLInterpDialect>(); |
120 | } |
121 | void runOnOperation() final { |
122 | ModuleOp module = getOperation(); |
123 | |
124 | // The test cases are encompassed via two modules, one containing the |
125 | // patterns and one containing the operations to rewrite. |
126 | ModuleOp patternModule = module.lookupSymbol<ModuleOp>( |
127 | StringAttr::get(module->getContext(), "patterns")); |
128 | ModuleOp irModule = module.lookupSymbol<ModuleOp>( |
129 | StringAttr::get(module->getContext(), "ir")); |
130 | if (!patternModule || !irModule) |
131 | return; |
132 | |
133 | RewritePatternSet patternList(module->getContext()); |
134 | |
135 | // Register ahead of time to test when functions are registered without a |
136 | // pattern. |
137 | patternList.getPDLPatterns().registerConstraintFunction( |
138 | name: "multi_entity_constraint", constraintFn&: customMultiEntityConstraint); |
139 | patternList.getPDLPatterns().registerConstraintFunction( |
140 | name: "single_entity_constraint", constraintFn&: customSingleEntityConstraint); |
141 | |
142 | // Process the pattern module. |
143 | patternModule.getOperation()->remove(); |
144 | PDLPatternModule pdlPattern(patternModule); |
145 | |
146 | // Note: This constraint was already registered, but we re-register here to |
147 | // ensure that duplication registration is allowed (the duplicate mapping |
148 | // will be ignored). This tests that we support separating the registration |
149 | // of library functions from the construction of patterns, and also that we |
150 | // allow multiple patterns to depend on the same library functions (without |
151 | // asserting/crashing). |
152 | pdlPattern.registerConstraintFunction(name: "multi_entity_constraint", |
153 | constraintFn&: customMultiEntityConstraint); |
154 | pdlPattern.registerConstraintFunction(name: "multi_entity_var_constraint", |
155 | constraintFn&: customMultiEntityVariadicConstraint); |
156 | pdlPattern.registerConstraintFunction(name: "op_constr_return_attr", |
157 | constraintFn&: customValueResultConstraint); |
158 | pdlPattern.registerConstraintFunction(name: "op_constr_return_type", |
159 | constraintFn&: customTypeResultConstraint); |
160 | pdlPattern.registerConstraintFunction(name: "op_multiple_returns_failure", |
161 | constraintFn&: customConstraintFailure); |
162 | pdlPattern.registerConstraintFunction(name: "op_constr_return_type_range", |
163 | constraintFn&: customTypeRangeResultConstraint); |
164 | pdlPattern.registerRewriteFunction(name: "creator", rewriteFn&: customCreate); |
165 | pdlPattern.registerRewriteFunction(name: "var_creator", |
166 | rewriteFn&: customVariadicResultCreate); |
167 | pdlPattern.registerRewriteFunction(name: "type_creator", rewriteFn&: customCreateType); |
168 | pdlPattern.registerRewriteFunction(name: "str_creator", rewriteFn&: customCreateStrAttr); |
169 | pdlPattern.registerRewriteFunction(name: "rewriter", rewriteFn&: customRewriter); |
170 | patternList.add(pattern: std::move(pdlPattern)); |
171 | |
172 | // Invoke the pattern driver with the provided patterns. |
173 | (void)applyPatternsGreedily(irModule.getBodyRegion(), |
174 | std::move(patternList)); |
175 | } |
176 | }; |
177 | } // namespace |
178 | |
179 | namespace mlir { |
180 | namespace test { |
181 | void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); } |
182 | } // namespace test |
183 | } // namespace mlir |
184 |
Definitions
- customSingleEntityConstraint
- customMultiEntityConstraint
- customMultiEntityVariadicConstraint
- customValueResultConstraint
- customTypeResultConstraint
- customConstraintFailure
- customTypeRangeResultConstraint
- customCreate
- customVariadicResultCreate
- customCreateType
- customCreateStrAttr
- customRewriter
- TestPDLByteCodePass
- getArgument
- getDescription
- getDependentDialects
- runOnOperation
Improve your Profiling and Debugging skills
Find out more