//===- TestPDLByteCode.cpp - Test rewriter bytecode functionality ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
10#include "mlir/Pass/Pass.h"
11#include "mlir/Pass/PassManager.h"
12#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13
14using namespace mlir;
15
16/// Custom constraint invoked from PDL.
17static LogicalResult customSingleEntityConstraint(PatternRewriter &rewriter,
18 Operation *rootOp) {
19 return success(IsSuccess: rootOp->getName().getStringRef() == "test.op");
20}
21static LogicalResult customMultiEntityConstraint(PatternRewriter &rewriter,
22 Operation *root,
23 Operation *rootCopy) {
24 return customSingleEntityConstraint(rewriter, rootOp: rootCopy);
25}
26static LogicalResult customMultiEntityVariadicConstraint(
27 PatternRewriter &rewriter, ValueRange operandValues, TypeRange typeValues) {
28 if (operandValues.size() != 2 || typeValues.size() != 2)
29 return failure();
30 return success();
31}
32
33// Custom constraint that returns a value if the op is named test.success_op
34static LogicalResult customValueResultConstraint(PatternRewriter &rewriter,
35 PDLResultList &results,
36 ArrayRef<PDLValue> args) {
37 auto *op = args[0].cast<Operation *>();
38 if (op->getName().getStringRef() == "test.success_op") {
39 StringAttr customAttr = rewriter.getStringAttr("test.success");
40 results.push_back(customAttr);
41 return success();
42 }
43 return failure();
44}
45
46// Custom constraint that returns a type if the op is named test.success_op
47static LogicalResult customTypeResultConstraint(PatternRewriter &rewriter,
48 PDLResultList &results,
49 ArrayRef<PDLValue> args) {
50 auto *op = args[0].cast<Operation *>();
51 if (op->getName().getStringRef() == "test.success_op") {
52 results.push_back(rewriter.getF32Type());
53 return success();
54 }
55 return failure();
56}
57
58// Custom constraint that always returns failure
59static LogicalResult customConstraintFailure(PatternRewriter & /*rewriter*/,
60 PDLResultList & /*results*/,
61 ArrayRef<PDLValue> /*args*/) {
62 return failure();
63}
64
65// Custom constraint that returns a type range of variable length if the op is
66// named test.success_op
67static LogicalResult customTypeRangeResultConstraint(PatternRewriter &rewriter,
68 PDLResultList &results,
69 ArrayRef<PDLValue> args) {
70 auto *op = args[0].cast<Operation *>();
71 int numTypes = cast<IntegerAttr>(args[1].cast<Attribute>()).getInt();
72
73 if (op->getName().getStringRef() == "test.success_op") {
74 SmallVector<Type> types;
75 for (int i = 0; i < numTypes; i++) {
76 types.push_back(rewriter.getF32Type());
77 }
78 results.push_back(value: TypeRange(types));
79 return success();
80 }
81 return failure();
82}
83
84// Custom creator invoked from PDL.
85static Operation *customCreate(PatternRewriter &rewriter, Operation *op) {
86 return rewriter.create(state: OperationState(op->getLoc(), "test.success"));
87}
88static auto customVariadicResultCreate(PatternRewriter &rewriter,
89 Operation *root) {
90 return std::make_pair(x: root->getOperands(), y: root->getOperands().getTypes());
91}
92static Type customCreateType(PatternRewriter &rewriter) {
93 return rewriter.getF32Type();
94}
95static std::string customCreateStrAttr(PatternRewriter &rewriter) {
96 return "test.str";
97}
98
99/// Custom rewriter invoked from PDL.
100static void customRewriter(PatternRewriter &rewriter, Operation *root,
101 Value input) {
102 rewriter.create(root->getLoc(), rewriter.getStringAttr("test.success"),
103 input);
104 rewriter.eraseOp(op: root);
105}
106
107namespace {
108struct TestPDLByteCodePass
109 : public PassWrapper<TestPDLByteCodePass, OperationPass<ModuleOp>> {
110 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPDLByteCodePass)
111
112 StringRef getArgument() const final { return "test-pdl-bytecode-pass"; }
113 StringRef getDescription() const final {
114 return "Test PDL ByteCode functionality";
115 }
116 void getDependentDialects(DialectRegistry &registry) const override {
117 // Mark the pdl_interp dialect as a dependent. This is needed, because we
118 // create ops from that dialect as a part of the PDL-to-PDLInterp lowering.
119 registry.insert<pdl_interp::PDLInterpDialect>();
120 }
121 void runOnOperation() final {
122 ModuleOp module = getOperation();
123
124 // The test cases are encompassed via two modules, one containing the
125 // patterns and one containing the operations to rewrite.
126 ModuleOp patternModule = module.lookupSymbol<ModuleOp>(
127 StringAttr::get(module->getContext(), "patterns"));
128 ModuleOp irModule = module.lookupSymbol<ModuleOp>(
129 StringAttr::get(module->getContext(), "ir"));
130 if (!patternModule || !irModule)
131 return;
132
133 RewritePatternSet patternList(module->getContext());
134
135 // Register ahead of time to test when functions are registered without a
136 // pattern.
137 patternList.getPDLPatterns().registerConstraintFunction(
138 name: "multi_entity_constraint", constraintFn&: customMultiEntityConstraint);
139 patternList.getPDLPatterns().registerConstraintFunction(
140 name: "single_entity_constraint", constraintFn&: customSingleEntityConstraint);
141
142 // Process the pattern module.
143 patternModule.getOperation()->remove();
144 PDLPatternModule pdlPattern(patternModule);
145
146 // Note: This constraint was already registered, but we re-register here to
147 // ensure that duplication registration is allowed (the duplicate mapping
148 // will be ignored). This tests that we support separating the registration
149 // of library functions from the construction of patterns, and also that we
150 // allow multiple patterns to depend on the same library functions (without
151 // asserting/crashing).
152 pdlPattern.registerConstraintFunction(name: "multi_entity_constraint",
153 constraintFn&: customMultiEntityConstraint);
154 pdlPattern.registerConstraintFunction(name: "multi_entity_var_constraint",
155 constraintFn&: customMultiEntityVariadicConstraint);
156 pdlPattern.registerConstraintFunction(name: "op_constr_return_attr",
157 constraintFn&: customValueResultConstraint);
158 pdlPattern.registerConstraintFunction(name: "op_constr_return_type",
159 constraintFn&: customTypeResultConstraint);
160 pdlPattern.registerConstraintFunction(name: "op_multiple_returns_failure",
161 constraintFn&: customConstraintFailure);
162 pdlPattern.registerConstraintFunction(name: "op_constr_return_type_range",
163 constraintFn&: customTypeRangeResultConstraint);
164 pdlPattern.registerRewriteFunction(name: "creator", rewriteFn&: customCreate);
165 pdlPattern.registerRewriteFunction(name: "var_creator",
166 rewriteFn&: customVariadicResultCreate);
167 pdlPattern.registerRewriteFunction(name: "type_creator", rewriteFn&: customCreateType);
168 pdlPattern.registerRewriteFunction(name: "str_creator", rewriteFn&: customCreateStrAttr);
169 pdlPattern.registerRewriteFunction(name: "rewriter", rewriteFn&: customRewriter);
170 patternList.add(pattern: std::move(pdlPattern));
171
172 // Invoke the pattern driver with the provided patterns.
173 (void)applyPatternsGreedily(irModule.getBodyRegion(),
174 std::move(patternList));
175 }
176};
177} // namespace
178
179namespace mlir {
180namespace test {
181void registerTestPDLByteCodePass() { PassRegistration<TestPDLByteCodePass>(); }
182} // namespace test
183} // namespace mlir
184

// Provided by KDAB
//
// Privacy Policy
// Improve your Profiling and Debugging skills
// Find out more
//
// source code of mlir/test/lib/Rewrite/TestPDLByteCode.cpp