1 | //===- TestMemRefToLLVMWithTransforms.cpp ---------------------------------===// |
---|---|
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" |
10 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" |
11 | #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" |
12 | #include "mlir/Conversion/LLVMCommon/TypeConverter.h" |
13 | #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" |
14 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
15 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
16 | #include "mlir/Dialect/MemRef/Transforms/Transforms.h" |
17 | #include "mlir/IR/PatternMatch.h" |
18 | #include "mlir/Pass/Pass.h" |
19 | |
20 | using namespace mlir; |
21 | |
22 | namespace { |
23 | |
24 | struct TestMemRefToLLVMWithTransforms |
25 | : public PassWrapper<TestMemRefToLLVMWithTransforms, |
26 | OperationPass<ModuleOp>> { |
27 | MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMemRefToLLVMWithTransforms) |
28 | |
29 | void getDependentDialects(DialectRegistry ®istry) const final { |
30 | registry.insert<LLVM::LLVMDialect>(); |
31 | } |
32 | |
33 | StringRef getArgument() const final { |
34 | return "test-memref-to-llvm-with-transforms"; |
35 | } |
36 | |
37 | StringRef getDescription() const final { |
38 | return "Tests conversion of MemRef dialects + `func.func` to LLVM dialect " |
39 | "with MemRef transforms."; |
40 | } |
41 | |
42 | void runOnOperation() override { |
43 | MLIRContext *ctx = &getContext(); |
44 | LowerToLLVMOptions options(ctx); |
45 | LLVMTypeConverter typeConverter(ctx, options); |
46 | RewritePatternSet patterns(ctx); |
47 | memref::populateExpandStridedMetadataPatterns(patterns); |
48 | populateFuncToLLVMConversionPatterns(converter: typeConverter, patterns); |
49 | LLVMConversionTarget target(getContext()); |
50 | if (failed(applyPartialConversion(getOperation(), target, |
51 | std::move(patterns)))) |
52 | signalPassFailure(); |
53 | } |
54 | }; |
55 | |
56 | } // namespace |
57 | |
58 | namespace mlir::test { |
59 | void registerTestMemRefToLLVMWithTransforms() { |
60 | PassRegistration<TestMemRefToLLVMWithTransforms>(); |
61 | } |
62 | } // namespace mlir::test |
63 |