1 | //===- TestDialectConversion.cpp - Test DialectConversion functionality ---===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "TestDialect.h" |
10 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
11 | #include "mlir/Dialect/PDL/IR/PDL.h" |
12 | #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" |
13 | #include "mlir/Parser/Parser.h" |
14 | #include "mlir/Pass/Pass.h" |
15 | #include "mlir/Pass/PassManager.h" |
16 | #include "mlir/Transforms/DialectConversion.h" |
17 | |
18 | using namespace mlir; |
19 | using namespace test; |
20 | |
21 | //===----------------------------------------------------------------------===// |
22 | // Test PDLL Support |
23 | //===----------------------------------------------------------------------===// |
24 | |
25 | #include "TestDialectConversionPDLLPatterns.h.inc" |
26 | |
27 | namespace { |
28 | struct PDLLTypeConverter : public TypeConverter { |
29 | PDLLTypeConverter() { |
30 | addConversion(convertType); |
31 | addArgumentMaterialization(materializeCast); |
32 | addSourceMaterialization(materializeCast); |
33 | } |
34 | |
35 | static LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) { |
36 | // Convert I64 to F64. |
37 | if (t.isSignlessInteger(width: 64)) { |
38 | results.push_back(Elt: FloatType::getF64(ctx: t.getContext())); |
39 | return success(); |
40 | } |
41 | |
42 | // Otherwise, convert the type directly. |
43 | results.push_back(Elt: t); |
44 | return success(); |
45 | } |
46 | /// Hook for materializing a conversion. |
47 | static std::optional<Value> materializeCast(OpBuilder &builder, |
48 | Type resultType, |
49 | ValueRange inputs, Location loc) { |
50 | return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs) |
51 | .getResult(0); |
52 | } |
53 | }; |
54 | |
55 | struct TestDialectConversionPDLLPass |
56 | : public PassWrapper<TestDialectConversionPDLLPass, OperationPass<>> { |
57 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDialectConversionPDLLPass) |
58 | |
59 | StringRef getArgument() const final { return "test-dialect-conversion-pdll" ; } |
60 | StringRef getDescription() const final { |
61 | return "Test DialectConversion PDLL functionality" ; |
62 | } |
63 | void getDependentDialects(DialectRegistry ®istry) const override { |
64 | registry.insert<pdl::PDLDialect, pdl_interp::PDLInterpDialect>(); |
65 | } |
66 | LogicalResult initialize(MLIRContext *ctx) override { |
67 | // Build the pattern set within the `initialize` to avoid recompiling PDL |
68 | // patterns during each `runOnOperation` invocation. |
69 | RewritePatternSet patternList(ctx); |
70 | registerConversionPDLFunctions(patterns&: patternList); |
71 | populateGeneratedPDLLPatterns(patternList, PDLConversionConfig(&converter)); |
72 | patterns = std::move(patternList); |
73 | return success(); |
74 | } |
75 | |
76 | void runOnOperation() final { |
77 | mlir::ConversionTarget target(getContext()); |
78 | target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>(); |
79 | target.addDynamicallyLegalDialect<TestDialect>( |
80 | [this](Operation *op) { return converter.isLegal(op); }); |
81 | |
82 | if (failed(mlir::applyFullConversion(getOperation(), target, patterns))) |
83 | signalPassFailure(); |
84 | } |
85 | |
86 | FrozenRewritePatternSet patterns; |
87 | PDLLTypeConverter converter; |
88 | }; |
89 | } // namespace |
90 | |
91 | namespace mlir { |
92 | namespace test { |
93 | void registerTestDialectConversionPasses() { |
94 | PassRegistration<TestDialectConversionPDLLPass>(); |
95 | } |
96 | } // namespace test |
97 | } // namespace mlir |
98 | |