1 | //===- TestDialectConversion.cpp - Test DialectConversion functionality ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Dialect/PDL/IR/PDL.h" |
12 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
13 | #include "mlir/Parser/Parser.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Pass/PassManager.h" |
16 | #include "mlir/Transforms/DialectConversion.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace test; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Test PDLL Support |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | #include "TestDialectConversionPDLLPatterns.h.inc" |
26 | |
27 | namespace { |
28 | struct PDLLTypeConverter : public TypeConverter { |
29 | PDLLTypeConverter() { |
30 | addConversion(convertType); |
31 | addSourceMaterialization(materializeCast); |
32 | } |
33 | |
34 | static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { |
35 | // Convert I64 to F64. |
36 | if (t.isSignlessInteger(width: 64)) { |
37 | results.push_back(Float64Type::get(t.getContext())); |
38 | return success(); |
39 | } |
40 | |
41 | // Otherwise, convert the type directly. |
42 | results.push_back(Elt: t); |
43 | return success(); |
44 | } |
45 | /// Hook for materializing a conversion. |
46 | static Value materializeCast(OpBuilder &builder, Type resultType, |
47 | ValueRange inputs, Location loc) { |
48 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
49 | .getResult(0); |
50 | } |
51 | }; |
52 | |
53 | struct TestDialectConversionPDLLPass |
54 | : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> { |
55 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass) |
56 | |
57 | StringRef getArgument() const final { return "test-dialect-conversion-pdll" ; } |
58 | StringRef getDescription() const final { |
59 | return "Test DialectConversion PDLL functionality" ; |
60 | } |
61 | void getDependentDialects(DialectRegistry ®istry) const override { |
62 | registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>(); |
63 | } |
64 | LogicalResult initialize(MLIRContext *ctx) override { |
65 | // Build the pattern set within the `initialize` to avoid recompiling PDL |
66 | // patterns during each `runOnOperation` invocation. |
67 | RewritePatternSet patternList(ctx); |
68 | registerConversionPDLFunctions(patterns&: patternList); |
69 | populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter)); |
70 | patterns = std::move(patternList); |
71 | return success(); |
72 | } |
73 | |
74 | void runOnOperation() final { |
75 | mlir::ConversionTarget target(getContext()); |
76 | target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(); |
77 | target.addDynamicallyLegalDialect<TestDialect>( |
78 | [this](Operation *op) { return converter.isLegal(op); }); |
79 | |
80 | if (failed(mlir::applyFullConversion(getOperation(), target, patterns))) |
81 | signalPassFailure(); |
82 | } |
83 | |
84 | FrozenRewritePatternSet patterns; |
85 | PDLLTypeConverter converter; |
86 | }; |
87 | } // namespace |
88 | |
89 | namespace mlir { |
90 | namespace test { |
91 | void registerTestDialectConversionPasses() { |
92 | PassRegistration<TestDialectConversionPDLLPass>(); |
93 | } |
94 | } // namespace test |
95 | } // namespace mlir |
96 | |