1 | //===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
10 | #define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
11 | |
12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
14 | |
15 | //===----------------------------------------------------------------------===// |
// Support for converting Arith attributes (fastmath flags, integer overflow
// flags, rounding modes) to their LLVM dialect counterparts
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | namespace mlir { |
20 | namespace arith { |
/// Maps arithmetic fastmath enum values to LLVM enum values.
///
/// Takes the arith dialect fastmath value \p arithFMF and returns the
/// corresponding LLVM dialect fastmath value. Only declared in this header;
/// the definition lives elsewhere.
LLVM::FastmathFlags
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);

/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
/// attribute.
///
/// Attribute-level counterpart of the enum mapping declared above; this is the
/// function the fastmath attribute converter class in this header calls once it
/// has extracted the arith fastmath attribute from the source operation.
LLVM::FastmathFlagsAttr
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);

/// Maps arithmetic overflow enum values to LLVM enum values.
///
/// Takes the arith dialect integer overflow value \p arithFlags and returns
/// the corresponding LLVM dialect integer overflow value.
LLVM::IntegerOverflowFlags
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);

/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
/// mode enum value.
///
/// Takes the arith dialect value \p roundingMode and returns the LLVM dialect
/// rounding mode value.
LLVM::RoundingMode
convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);

/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
/// mode attribute.
///
/// Attribute-level counterpart of the rounding mode enum conversion declared
/// above; used by the constrained FP attribute converter class in this header.
LLVM::RoundingModeAttr
convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);

/// Returns an attribute for the default LLVM FP exception behavior.
///
/// \p context is the MLIRContext passed by the caller; the constrained FP
/// attribute converter in this header passes the source operation's context.
LLVM::FPExceptionBehaviorAttr
getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
47 | |
48 | // Attribute converter that populates a NamedAttrList by removing the fastmath |
49 | // attribute from the source operation attributes, and replacing it with an |
50 | // equivalent LLVM fastmath attribute. |
51 | template <typename SourceOp, typename TargetOp> |
52 | class AttrConvertFastMathToLLVM { |
53 | public: |
54 | AttrConvertFastMathToLLVM(SourceOp srcOp) { |
55 | // Copy the source attributes. |
56 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
57 | // Get the name of the arith fastmath attribute. |
58 | StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); |
59 | // Remove the source fastmath attribute. |
60 | auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>( |
61 | convertedAttr.erase(arithFMFAttrName)); |
62 | if (arithFMFAttr) { |
63 | StringRef targetAttrName = TargetOp::getFastmathAttrName(); |
64 | convertedAttr.set(targetAttrName, |
65 | convertArithFastMathAttrToLLVM(arithFMFAttr)); |
66 | } |
67 | } |
68 | |
69 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
70 | LLVM::IntegerOverflowFlags getOverflowFlags() const { |
71 | return LLVM::IntegerOverflowFlags::none; |
72 | } |
73 | |
74 | private: |
75 | NamedAttrList convertedAttr; |
76 | }; |
77 | |
78 | // Attribute converter that populates a NamedAttrList by removing the overflow |
79 | // attribute from the source operation attributes, and replacing it with an |
80 | // equivalent LLVM overflow attribute. |
81 | template <typename SourceOp, typename TargetOp> |
82 | class AttrConvertOverflowToLLVM { |
83 | public: |
84 | AttrConvertOverflowToLLVM(SourceOp srcOp) { |
85 | // Copy the source attributes. |
86 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
87 | // Get the name of the arith overflow attribute. |
88 | StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); |
89 | // Remove the source overflow attribute. |
90 | if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( |
91 | convertedAttr.erase(arithAttrName))) { |
92 | overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); |
93 | } |
94 | } |
95 | |
96 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
97 | LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } |
98 | |
99 | private: |
100 | NamedAttrList convertedAttr; |
101 | LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; |
102 | }; |
103 | |
104 | template <typename SourceOp, typename TargetOp> |
105 | class AttrConverterConstrainedFPToLLVM { |
106 | static_assert(TargetOp::template hasTrait< |
107 | LLVM::FPExceptionBehaviorOpInterface::Trait>(), |
108 | "Target constrained FP operations must implement " |
109 | "LLVM::FPExceptionBehaviorOpInterface" ); |
110 | |
111 | public: |
112 | AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { |
113 | // Copy the source attributes. |
114 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
115 | |
116 | if constexpr (TargetOp::template hasTrait< |
117 | LLVM::RoundingModeOpInterface::Trait>()) { |
118 | // Get the name of the rounding mode attribute. |
119 | StringRef arithAttrName = srcOp.getRoundingModeAttrName(); |
120 | // Remove the source attribute. |
121 | auto arithAttr = |
122 | cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName)); |
123 | // Set the target attribute. |
124 | convertedAttr.set(TargetOp::getRoundingModeAttrName(), |
125 | convertArithRoundingModeAttrToLLVM(arithAttr)); |
126 | } |
127 | convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(), |
128 | getLLVMDefaultFPExceptionBehavior(*srcOp->getContext())); |
129 | } |
130 | |
131 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
132 | LLVM::IntegerOverflowFlags getOverflowFlags() const { |
133 | return LLVM::IntegerOverflowFlags::none; |
134 | } |
135 | |
136 | private: |
137 | NamedAttrList convertedAttr; |
138 | }; |
139 | |
140 | } // namespace arith |
141 | } // namespace mlir |
142 | |
143 | #endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
144 | |