1 | //===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements patterns to convert Complex dialect to SPIR-V dialect. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h" |
14 | #include "mlir/Dialect/Complex/IR/Complex.h" |
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" |
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
17 | #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" |
18 | #include "mlir/Transforms/DialectConversion.h" |
19 | #include "llvm/Support/Debug.h" |
20 | |
21 | #define DEBUG_TYPE "complex-to-spirv-pattern" |
22 | |
23 | using namespace mlir; |
24 | |
25 | //===----------------------------------------------------------------------===// |
26 | // Operation conversion |
27 | //===----------------------------------------------------------------------===// |
28 | |
29 | namespace { |
30 | |
31 | struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> { |
32 | using OpConversionPattern::OpConversionPattern; |
33 | |
34 | LogicalResult |
35 | matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor, |
36 | ConversionPatternRewriter &rewriter) const override { |
37 | auto spirvType = |
38 | getTypeConverter()->convertType<ShapedType>(constOp.getType()); |
39 | if (!spirvType) |
40 | return rewriter.notifyMatchFailure(constOp, |
41 | "unable to convert result type" ); |
42 | |
43 | rewriter.replaceOpWithNewOp<spirv::ConstantOp>( |
44 | constOp, spirvType, |
45 | DenseElementsAttr::get(spirvType, constOp.getValue().getValue())); |
46 | return success(); |
47 | } |
48 | }; |
49 | |
50 | struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> { |
51 | using OpConversionPattern::OpConversionPattern; |
52 | |
53 | LogicalResult |
54 | matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor, |
55 | ConversionPatternRewriter &rewriter) const override { |
56 | Type spirvType = getTypeConverter()->convertType(createOp.getType()); |
57 | if (!spirvType) |
58 | return rewriter.notifyMatchFailure(createOp, |
59 | "unable to convert result type" ); |
60 | |
61 | rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>( |
62 | createOp, spirvType, adaptor.getOperands()); |
63 | return success(); |
64 | } |
65 | }; |
66 | |
67 | struct ReOpPattern final : OpConversionPattern<complex::ReOp> { |
68 | using OpConversionPattern::OpConversionPattern; |
69 | |
70 | LogicalResult |
71 | matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor, |
72 | ConversionPatternRewriter &rewriter) const override { |
73 | Type spirvType = getTypeConverter()->convertType(reOp.getType()); |
74 | if (!spirvType) |
75 | return rewriter.notifyMatchFailure(reOp, "unable to convert result type" ); |
76 | |
77 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
78 | reOp, adaptor.getComplex(), llvm::ArrayRef(0)); |
79 | return success(); |
80 | } |
81 | }; |
82 | |
83 | struct ImOpPattern final : OpConversionPattern<complex::ImOp> { |
84 | using OpConversionPattern::OpConversionPattern; |
85 | |
86 | LogicalResult |
87 | matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor, |
88 | ConversionPatternRewriter &rewriter) const override { |
89 | Type spirvType = getTypeConverter()->convertType(imOp.getType()); |
90 | if (!spirvType) |
91 | return rewriter.notifyMatchFailure(imOp, "unable to convert result type" ); |
92 | |
93 | rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( |
94 | imOp, adaptor.getComplex(), llvm::ArrayRef(1)); |
95 | return success(); |
96 | } |
97 | }; |
98 | |
99 | } // namespace |
100 | |
101 | //===----------------------------------------------------------------------===// |
102 | // Pattern population |
103 | //===----------------------------------------------------------------------===// |
104 | |
105 | void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter, |
106 | RewritePatternSet &patterns) { |
107 | MLIRContext *context = patterns.getContext(); |
108 | |
109 | patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>( |
110 | arg&: typeConverter, args&: context); |
111 | } |
112 | |