1//===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
10#define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
11
12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14
//===----------------------------------------------------------------------===//
// Support for converting Arith attributes (fast-math flags, integer overflow
// flags, rounding modes) to their LLVM dialect counterparts
//===----------------------------------------------------------------------===//
18
19namespace mlir {
20namespace arith {
21/// Maps arithmetic fastmath enum values to LLVM enum values.
22LLVM::FastmathFlags
23convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
24
25/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
26/// attribute.
27LLVM::FastmathFlagsAttr
28convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
29
30/// Maps arithmetic overflow enum values to LLVM enum values.
31LLVM::IntegerOverflowFlags
32convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
33
34/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
35/// mode enum value.
36LLVM::RoundingMode
37convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode);
38
39/// Creates an LLVM rounding mode attribute from a given arithmetic rounding
40/// mode attribute.
41LLVM::RoundingModeAttr
42convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr);
43
44/// Returns an attribute for the default LLVM FP exception behavior.
45LLVM::FPExceptionBehaviorAttr
46getLLVMDefaultFPExceptionBehavior(MLIRContext &context);
47
48// Attribute converter that populates a NamedAttrList by removing the fastmath
49// attribute from the source operation attributes, and replacing it with an
50// equivalent LLVM fastmath attribute.
51template <typename SourceOp, typename TargetOp>
52class AttrConvertFastMathToLLVM {
53public:
54 AttrConvertFastMathToLLVM(SourceOp srcOp) {
55 // Copy the source attributes.
56 convertedAttr = NamedAttrList{srcOp->getAttrs()};
57 // Get the name of the arith fastmath attribute.
58 StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
59 // Remove the source fastmath attribute.
60 auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
61 convertedAttr.erase(arithFMFAttrName));
62 if (arithFMFAttr) {
63 StringRef targetAttrName = TargetOp::getFastmathAttrName();
64 convertedAttr.set(targetAttrName,
65 convertArithFastMathAttrToLLVM(arithFMFAttr));
66 }
67 }
68
69 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
70 LLVM::IntegerOverflowFlags getOverflowFlags() const {
71 return LLVM::IntegerOverflowFlags::none;
72 }
73
74private:
75 NamedAttrList convertedAttr;
76};
77
78// Attribute converter that populates a NamedAttrList by removing the overflow
79// attribute from the source operation attributes, and replacing it with an
80// equivalent LLVM overflow attribute.
81template <typename SourceOp, typename TargetOp>
82class AttrConvertOverflowToLLVM {
83public:
84 AttrConvertOverflowToLLVM(SourceOp srcOp) {
85 // Copy the source attributes.
86 convertedAttr = NamedAttrList{srcOp->getAttrs()};
87 // Get the name of the arith overflow attribute.
88 StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
89 // Remove the source overflow attribute.
90 if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
91 convertedAttr.erase(arithAttrName))) {
92 overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
93 }
94 }
95
96 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
97 LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
98
99private:
100 NamedAttrList convertedAttr;
101 LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
102};
103
104template <typename SourceOp, typename TargetOp>
105class AttrConverterConstrainedFPToLLVM {
106 static_assert(TargetOp::template hasTrait<
107 LLVM::FPExceptionBehaviorOpInterface::Trait>(),
108 "Target constrained FP operations must implement "
109 "LLVM::FPExceptionBehaviorOpInterface");
110
111public:
112 AttrConverterConstrainedFPToLLVM(SourceOp srcOp) {
113 // Copy the source attributes.
114 convertedAttr = NamedAttrList{srcOp->getAttrs()};
115
116 if constexpr (TargetOp::template hasTrait<
117 LLVM::RoundingModeOpInterface::Trait>()) {
118 // Get the name of the rounding mode attribute.
119 StringRef arithAttrName = srcOp.getRoundingModeAttrName();
120 // Remove the source attribute.
121 auto arithAttr =
122 cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName));
123 // Set the target attribute.
124 convertedAttr.set(TargetOp::getRoundingModeAttrName(),
125 convertArithRoundingModeAttrToLLVM(arithAttr));
126 }
127 convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(),
128 getLLVMDefaultFPExceptionBehavior(*srcOp->getContext()));
129 }
130
131 ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
132 LLVM::IntegerOverflowFlags getOverflowFlags() const {
133 return LLVM::IntegerOverflowFlags::none;
134 }
135
136private:
137 NamedAttrList convertedAttr;
138};
139
140} // namespace arith
141} // namespace mlir
142
143#endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H
144

source code of mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h