1//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Complex dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14#include "mlir/Dialect/Complex/IR/Complex.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17#include "mlir/Transforms/DialectConversion.h"
18
19#define DEBUG_TYPE "complex-to-spirv-pattern"
20
21using namespace mlir;
22
23//===----------------------------------------------------------------------===//
24// Operation conversion
25//===----------------------------------------------------------------------===//
26
27namespace {
28
29struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
30 using OpConversionPattern::OpConversionPattern;
31
32 LogicalResult
33 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
34 ConversionPatternRewriter &rewriter) const override {
35 auto spirvType =
36 getTypeConverter()->convertType<ShapedType>(t: constOp.getType());
37 if (!spirvType)
38 return rewriter.notifyMatchFailure(arg&: constOp,
39 msg: "unable to convert result type");
40
41 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
42 op: constOp, args&: spirvType,
43 args: DenseElementsAttr::get(type: spirvType, values: constOp.getValue().getValue()));
44 return success();
45 }
46};
47
48struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
49 using OpConversionPattern::OpConversionPattern;
50
51 LogicalResult
52 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 Type spirvType = getTypeConverter()->convertType(t: createOp.getType());
55 if (!spirvType)
56 return rewriter.notifyMatchFailure(arg&: createOp,
57 msg: "unable to convert result type");
58
59 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
60 op: createOp, args&: spirvType, args: adaptor.getOperands());
61 return success();
62 }
63};
64
65struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
66 using OpConversionPattern::OpConversionPattern;
67
68 LogicalResult
69 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
70 ConversionPatternRewriter &rewriter) const override {
71 Type spirvType = getTypeConverter()->convertType(t: reOp.getType());
72 if (!spirvType)
73 return rewriter.notifyMatchFailure(arg&: reOp, msg: "unable to convert result type");
74
75 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
76 op: reOp, args: adaptor.getComplex(), args: llvm::ArrayRef(0));
77 return success();
78 }
79};
80
81struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
82 using OpConversionPattern::OpConversionPattern;
83
84 LogicalResult
85 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
86 ConversionPatternRewriter &rewriter) const override {
87 Type spirvType = getTypeConverter()->convertType(t: imOp.getType());
88 if (!spirvType)
89 return rewriter.notifyMatchFailure(arg&: imOp, msg: "unable to convert result type");
90
91 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
92 op: imOp, args: adaptor.getComplex(), args: llvm::ArrayRef(1));
93 return success();
94 }
95};
96
97} // namespace
98
99//===----------------------------------------------------------------------===//
100// Pattern population
101//===----------------------------------------------------------------------===//
102
103void mlir::populateComplexToSPIRVPatterns(
104 const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
105 MLIRContext *context = patterns.getContext();
106
107 patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
108 arg: typeConverter, args&: context);
109}
110

source code of mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp