1 | //===- ComplexDialect.cpp - MLIR Complex Dialect --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" |
10 | #include "mlir/Dialect/Arith/IR/Arith.h" |
11 | #include "mlir/Dialect/Complex/IR/Complex.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/DialectImplementation.h" |
14 | #include "mlir/Transforms/InliningUtils.h" |
15 | #include "llvm/ADT/StringExtras.h" |
16 | #include "llvm/ADT/TypeSwitch.h" |
17 | |
18 | using namespace mlir; |
19 | |
20 | #include "mlir/Dialect/Complex/IR/ComplexOpsDialect.cpp.inc" |
21 | |
22 | namespace { |
23 | /// This class defines the interface for handling inlining for complex |
24 | /// dialect operations. |
25 | struct ComplexInlinerInterface : public DialectInlinerInterface { |
26 | using DialectInlinerInterface::DialectInlinerInterface; |
27 | /// All complex dialect ops can be inlined. |
28 | bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final { |
29 | return true; |
30 | } |
31 | }; |
32 | } // namespace |
33 | |
/// Registers everything the complex dialect provides with its context:
///   - the operations listed in the generated ComplexOps.cpp.inc
///     (GET_OP_LIST);
///   - the attributes listed in the generated ComplexAttributes.cpp.inc
///     (GET_ATTRDEF_LIST);
///   - a promise of a ConvertToLLVMPatternInterface implementation;
///   - the ComplexInlinerInterface declared in this file.
void complex::ComplexDialect::initialize() {
  // The generated file expands to a comma-separated list of op classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
      >();
  // The generated file expands to a comma-separated list of attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Complex/IR/ComplexAttributes.cpp.inc"
      >();
  // Only a promise is declared here. NOTE(review): the actual interface
  // implementation is presumably attached elsewhere (e.g. by a dialect
  // extension in the conversion library) — confirm where it is registered.
  declarePromisedInterface<ConvertToLLVMPatternInterface, ComplexDialect>();
  addInterfaces<ComplexInlinerInterface>();
}
46 | |
47 | Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder, |
48 | Attribute value, |
49 | Type type, |
50 | Location loc) { |
51 | if (complex::ConstantOp::isBuildableWith(value, type)) { |
52 | return builder.create<complex::ConstantOp>(loc, type, |
53 | llvm::cast<ArrayAttr>(value)); |
54 | } |
55 | return arith::ConstantOp::materialize(builder, value, type, loc); |
56 | } |
57 | |
58 | #define GET_ATTRDEF_CLASSES |
59 | #include "mlir/Dialect/Complex/IR/ComplexAttributes.cpp.inc" |
60 | |
61 | LogicalResult complex::NumberAttr::verify( |
62 | ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, |
63 | ::llvm::APFloat real, ::llvm::APFloat imag, ::mlir::Type type) { |
64 | |
65 | if (!llvm::isa<ComplexType>(type)) |
66 | return emitError() << "complex attribute must be a complex type." ; |
67 | |
68 | Type elementType = llvm::cast<ComplexType>(type).getElementType(); |
69 | if (!llvm::isa<FloatType>(elementType)) |
70 | return emitError() |
71 | << "element type of the complex attribute must be float like type." ; |
72 | |
73 | const auto &typeFloatSemantics = |
74 | llvm::cast<FloatType>(elementType).getFloatSemantics(); |
75 | if (&real.getSemantics() != &typeFloatSemantics) |
76 | return emitError() |
77 | << "type doesn't match the type implied by its `real` value" ; |
78 | if (&imag.getSemantics() != &typeFloatSemantics) |
79 | return emitError() |
80 | << "type doesn't match the type implied by its `imag` value" ; |
81 | |
82 | return success(); |
83 | } |
84 | |
85 | void complex::NumberAttr::print(AsmPrinter &printer) const { |
86 | printer << "<:" << llvm::cast<ComplexType>(getType()).getElementType() << " " |
87 | << getReal() << ", " << getImag() << ">" ; |
88 | } |
89 | |
90 | Attribute complex::NumberAttr::parse(AsmParser &parser, Type odsType) { |
91 | Type type; |
92 | double real, imag; |
93 | if (parser.parseLess() || parser.parseColon() || parser.parseType(type) || |
94 | parser.parseFloat(real) || parser.parseComma() || |
95 | parser.parseFloat(imag) || parser.parseGreater()) |
96 | return {}; |
97 | |
98 | return NumberAttr::get(ComplexType::get(type), real, imag); |
99 | } |
100 | |