| 1 | //===- AttrToLLVMConverter.h - Arith attributes conversion ------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #ifndef MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
| 10 | #define MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
| 11 | |
| 12 | #include "mlir/Dialect/Arith/IR/Arith.h" |
| 13 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h" |
| 14 | |
| 15 | //===----------------------------------------------------------------------===// |
// Support for converting Arith attributes to their LLVM equivalents
| 17 | //===----------------------------------------------------------------------===// |
| 18 | |
| 19 | namespace mlir { |
| 20 | namespace arith { |
| 21 | /// Maps arithmetic fastmath enum values to LLVM enum values. |
| 22 | LLVM::FastmathFlags |
| 23 | convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF); |
| 24 | |
| 25 | /// Creates an LLVM fastmath attribute from a given arithmetic fastmath |
| 26 | /// attribute. |
| 27 | LLVM::FastmathFlagsAttr |
| 28 | convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr); |
| 29 | |
| 30 | /// Maps arithmetic overflow enum values to LLVM enum values. |
| 31 | LLVM::IntegerOverflowFlags |
| 32 | convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags); |
| 33 | |
| 34 | /// Creates an LLVM rounding mode enum value from a given arithmetic rounding |
| 35 | /// mode enum value. |
| 36 | LLVM::RoundingMode |
| 37 | convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode); |
| 38 | |
| 39 | /// Creates an LLVM rounding mode attribute from a given arithmetic rounding |
| 40 | /// mode attribute. |
| 41 | LLVM::RoundingModeAttr |
| 42 | convertArithRoundingModeAttrToLLVM(arith::RoundingModeAttr roundingModeAttr); |
| 43 | |
| 44 | /// Returns an attribute for the default LLVM FP exception behavior. |
| 45 | LLVM::FPExceptionBehaviorAttr |
| 46 | getLLVMDefaultFPExceptionBehavior(MLIRContext &context); |
| 47 | |
| 48 | // Attribute converter that populates a NamedAttrList by removing the fastmath |
| 49 | // attribute from the source operation attributes, and replacing it with an |
| 50 | // equivalent LLVM fastmath attribute. |
| 51 | template <typename SourceOp, typename TargetOp> |
| 52 | class AttrConvertFastMathToLLVM { |
| 53 | public: |
| 54 | AttrConvertFastMathToLLVM(SourceOp srcOp) { |
| 55 | // Copy the source attributes. |
| 56 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
| 57 | // Get the name of the arith fastmath attribute. |
| 58 | StringRef arithFMFAttrName = SourceOp::getFastMathAttrName(); |
| 59 | // Remove the source fastmath attribute. |
| 60 | auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>( |
| 61 | convertedAttr.erase(arithFMFAttrName)); |
| 62 | if (arithFMFAttr) { |
| 63 | StringRef targetAttrName = TargetOp::getFastmathAttrName(); |
| 64 | convertedAttr.set(targetAttrName, |
| 65 | convertArithFastMathAttrToLLVM(arithFMFAttr)); |
| 66 | } |
| 67 | } |
| 68 | |
| 69 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
| 70 | LLVM::IntegerOverflowFlags getOverflowFlags() const { |
| 71 | return LLVM::IntegerOverflowFlags::none; |
| 72 | } |
| 73 | |
| 74 | private: |
| 75 | NamedAttrList convertedAttr; |
| 76 | }; |
| 77 | |
| 78 | // Attribute converter that populates a NamedAttrList by removing the overflow |
| 79 | // attribute from the source operation attributes, and replacing it with an |
| 80 | // equivalent LLVM overflow attribute. |
| 81 | template <typename SourceOp, typename TargetOp> |
| 82 | class AttrConvertOverflowToLLVM { |
| 83 | public: |
| 84 | AttrConvertOverflowToLLVM(SourceOp srcOp) { |
| 85 | // Copy the source attributes. |
| 86 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
| 87 | // Get the name of the arith overflow attribute. |
| 88 | StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName(); |
| 89 | // Remove the source overflow attribute. |
| 90 | if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>( |
| 91 | convertedAttr.erase(arithAttrName))) { |
| 92 | overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue()); |
| 93 | } |
| 94 | } |
| 95 | |
| 96 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
| 97 | LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; } |
| 98 | |
| 99 | private: |
| 100 | NamedAttrList convertedAttr; |
| 101 | LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none; |
| 102 | }; |
| 103 | |
| 104 | template <typename SourceOp, typename TargetOp> |
| 105 | class AttrConverterConstrainedFPToLLVM { |
| 106 | static_assert(TargetOp::template hasTrait< |
| 107 | LLVM::FPExceptionBehaviorOpInterface::Trait>(), |
| 108 | "Target constrained FP operations must implement " |
| 109 | "LLVM::FPExceptionBehaviorOpInterface" ); |
| 110 | |
| 111 | public: |
| 112 | AttrConverterConstrainedFPToLLVM(SourceOp srcOp) { |
| 113 | // Copy the source attributes. |
| 114 | convertedAttr = NamedAttrList{srcOp->getAttrs()}; |
| 115 | |
| 116 | if constexpr (TargetOp::template hasTrait< |
| 117 | LLVM::RoundingModeOpInterface::Trait>()) { |
| 118 | // Get the name of the rounding mode attribute. |
| 119 | StringRef arithAttrName = srcOp.getRoundingModeAttrName(); |
| 120 | // Remove the source attribute. |
| 121 | auto arithAttr = |
| 122 | cast<arith::RoundingModeAttr>(convertedAttr.erase(arithAttrName)); |
| 123 | // Set the target attribute. |
| 124 | convertedAttr.set(TargetOp::getRoundingModeAttrName(), |
| 125 | convertArithRoundingModeAttrToLLVM(arithAttr)); |
| 126 | } |
| 127 | convertedAttr.set(TargetOp::getFPExceptionBehaviorAttrName(), |
| 128 | getLLVMDefaultFPExceptionBehavior(*srcOp->getContext())); |
| 129 | } |
| 130 | |
| 131 | ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); } |
| 132 | LLVM::IntegerOverflowFlags getOverflowFlags() const { |
| 133 | return LLVM::IntegerOverflowFlags::none; |
| 134 | } |
| 135 | |
| 136 | private: |
| 137 | NamedAttrList convertedAttr; |
| 138 | }; |
| 139 | |
| 140 | } // namespace arith |
| 141 | } // namespace mlir |
| 142 | |
| 143 | #endif // MLIR_CONVERSION_ARITHCOMMON_ATTRTOLLVMCONVERTER_H |
| 144 | |