1//===- ComplexToSPIRV.cpp - Complex to SPIR-V Patterns --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert Complex dialect to SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Conversion/ComplexToSPIRV/ComplexToSPIRV.h"
14#include "mlir/Dialect/Complex/IR/Complex.h"
15#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/Support/Debug.h"
20
21#define DEBUG_TYPE "complex-to-spirv-pattern"
22
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// Operation conversion
27//===----------------------------------------------------------------------===//
28
29namespace {
30
31struct ConstantOpPattern final : OpConversionPattern<complex::ConstantOp> {
32 using OpConversionPattern::OpConversionPattern;
33
34 LogicalResult
35 matchAndRewrite(complex::ConstantOp constOp, OpAdaptor adaptor,
36 ConversionPatternRewriter &rewriter) const override {
37 auto spirvType =
38 getTypeConverter()->convertType<ShapedType>(constOp.getType());
39 if (!spirvType)
40 return rewriter.notifyMatchFailure(constOp,
41 "unable to convert result type");
42
43 rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
44 constOp, spirvType,
45 DenseElementsAttr::get(spirvType, constOp.getValue().getValue()));
46 return success();
47 }
48};
49
50struct CreateOpPattern final : OpConversionPattern<complex::CreateOp> {
51 using OpConversionPattern::OpConversionPattern;
52
53 LogicalResult
54 matchAndRewrite(complex::CreateOp createOp, OpAdaptor adaptor,
55 ConversionPatternRewriter &rewriter) const override {
56 Type spirvType = getTypeConverter()->convertType(createOp.getType());
57 if (!spirvType)
58 return rewriter.notifyMatchFailure(createOp,
59 "unable to convert result type");
60
61 rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
62 createOp, spirvType, adaptor.getOperands());
63 return success();
64 }
65};
66
67struct ReOpPattern final : OpConversionPattern<complex::ReOp> {
68 using OpConversionPattern::OpConversionPattern;
69
70 LogicalResult
71 matchAndRewrite(complex::ReOp reOp, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73 Type spirvType = getTypeConverter()->convertType(reOp.getType());
74 if (!spirvType)
75 return rewriter.notifyMatchFailure(reOp, "unable to convert result type");
76
77 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
78 reOp, adaptor.getComplex(), llvm::ArrayRef(0));
79 return success();
80 }
81};
82
83struct ImOpPattern final : OpConversionPattern<complex::ImOp> {
84 using OpConversionPattern::OpConversionPattern;
85
86 LogicalResult
87 matchAndRewrite(complex::ImOp imOp, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 Type spirvType = getTypeConverter()->convertType(imOp.getType());
90 if (!spirvType)
91 return rewriter.notifyMatchFailure(imOp, "unable to convert result type");
92
93 rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
94 imOp, adaptor.getComplex(), llvm::ArrayRef(1));
95 return success();
96 }
97};
98
99} // namespace
100
101//===----------------------------------------------------------------------===//
102// Pattern population
103//===----------------------------------------------------------------------===//
104
105void mlir::populateComplexToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
106 RewritePatternSet &patterns) {
107 MLIRContext *context = patterns.getContext();
108
109 patterns.add<ConstantOpPattern, CreateOpPattern, ReOpPattern, ImOpPattern>(
110 arg&: typeConverter, args&: context);
111}
112

source code of mlir/lib/Conversion/ComplexToSPIRV/ComplexToSPIRV.cpp