1 | //====- LoweringHelpers.cpp - Lowering helper functions -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains helper functions for lowering from CIR to LLVM or MLIR. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "clang/CIR/LoweringHelpers.h" |
14 | #include "clang/CIR/MissingFeatures.h" |
15 | |
16 | mlir::DenseElementsAttr |
17 | convertStringAttrToDenseElementsAttr(cir::ConstArrayAttr attr, |
18 | mlir::Type type) { |
19 | auto values = llvm::SmallVector<mlir::APInt, 8>{}; |
20 | const auto stringAttr = mlir::cast<mlir::StringAttr>(attr.getElts()); |
21 | |
22 | for (const char element : stringAttr) |
23 | values.push_back({8, (uint64_t)element}); |
24 | |
25 | const auto arrayTy = mlir::cast<cir::ArrayType>(attr.getType()); |
26 | if (arrayTy.getSize() != stringAttr.size()) |
27 | assert(!cir::MissingFeatures::stringTypeWithDifferentArraySize()); |
28 | |
29 | return mlir::DenseElementsAttr::get( |
30 | mlir::RankedTensorType::get({(int64_t)values.size()}, type), |
31 | llvm::ArrayRef(values)); |
32 | } |
33 | |
34 | template <> mlir::APInt getZeroInitFromType(mlir::Type ty) { |
35 | assert(mlir::isa<cir::IntType>(ty) && "expected int type" ); |
36 | const auto intTy = mlir::cast<cir::IntType>(ty); |
37 | return mlir::APInt::getZero(intTy.getWidth()); |
38 | } |
39 | |
40 | template <> mlir::APFloat getZeroInitFromType(mlir::Type ty) { |
41 | assert((mlir::isa<cir::SingleType, cir::DoubleType>(ty)) && |
42 | "only float and double supported" ); |
43 | |
44 | if (ty.isF32() || mlir::isa<cir::SingleType>(ty)) |
45 | return mlir::APFloat(0.f); |
46 | |
47 | if (ty.isF64() || mlir::isa<cir::DoubleType>(ty)) |
48 | return mlir::APFloat(0.0); |
49 | |
50 | llvm_unreachable("NYI" ); |
51 | } |
52 | |
53 | /// \param attr the ConstArrayAttr to convert |
54 | /// \param values the output parameter, the values array to fill |
55 | /// \param currentDims the shpae of tensor we're going to convert to |
56 | /// \param dimIndex the current dimension we're processing |
57 | /// \param currentIndex the current index in the values array |
58 | template <typename AttrTy, typename StorageTy> |
59 | void convertToDenseElementsAttrImpl( |
60 | cir::ConstArrayAttr attr, llvm::SmallVectorImpl<StorageTy> &values, |
61 | const llvm::SmallVectorImpl<int64_t> ¤tDims, int64_t dimIndex, |
62 | int64_t currentIndex) { |
63 | if (auto stringAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) { |
64 | if (auto arrayType = mlir::dyn_cast<cir::ArrayType>(attr.getType())) { |
65 | for (auto element : stringAttr) { |
66 | auto intAttr = cir::IntAttr::get(arrayType.getElementType(), element); |
67 | values[currentIndex++] = mlir::dyn_cast<AttrTy>(intAttr).getValue(); |
68 | } |
69 | return; |
70 | } |
71 | } |
72 | |
73 | dimIndex++; |
74 | std::size_t elementsSizeInCurrentDim = 1; |
75 | for (std::size_t i = dimIndex; i < currentDims.size(); i++) |
76 | elementsSizeInCurrentDim *= currentDims[i]; |
77 | |
78 | auto arrayAttr = mlir::cast<mlir::ArrayAttr>(attr.getElts()); |
79 | for (auto eltAttr : arrayAttr) { |
80 | if (auto valueAttr = mlir::dyn_cast<AttrTy>(eltAttr)) { |
81 | values[currentIndex++] = valueAttr.getValue(); |
82 | continue; |
83 | } |
84 | |
85 | if (auto subArrayAttr = mlir::dyn_cast<cir::ConstArrayAttr>(eltAttr)) { |
86 | convertToDenseElementsAttrImpl<AttrTy>(subArrayAttr, values, currentDims, |
87 | dimIndex, currentIndex); |
88 | currentIndex += elementsSizeInCurrentDim; |
89 | continue; |
90 | } |
91 | |
92 | if (mlir::isa<cir::ZeroAttr, cir::UndefAttr>(eltAttr)) { |
93 | currentIndex += elementsSizeInCurrentDim; |
94 | continue; |
95 | } |
96 | |
97 | llvm_unreachable("unknown element in ConstArrayAttr" ); |
98 | } |
99 | } |
100 | |
101 | template <typename AttrTy, typename StorageTy> |
102 | mlir::DenseElementsAttr convertToDenseElementsAttr( |
103 | cir::ConstArrayAttr attr, const llvm::SmallVectorImpl<int64_t> &dims, |
104 | mlir::Type elementType, mlir::Type convertedElementType) { |
105 | unsigned vectorSize = 1; |
106 | for (auto dim : dims) |
107 | vectorSize *= dim; |
108 | auto values = llvm::SmallVector<StorageTy, 8>( |
109 | vectorSize, getZeroInitFromType<StorageTy>(elementType)); |
110 | convertToDenseElementsAttrImpl<AttrTy>(attr, values, dims, /*currentDim=*/0, |
111 | /*initialIndex=*/0); |
112 | return mlir::DenseElementsAttr::get( |
113 | mlir::RankedTensorType::get(dims, convertedElementType), |
114 | llvm::ArrayRef(values)); |
115 | } |
116 | |
117 | std::optional<mlir::Attribute> |
118 | lowerConstArrayAttr(cir::ConstArrayAttr constArr, |
119 | const mlir::TypeConverter *converter) { |
120 | // Ensure ConstArrayAttr has a type. |
121 | const auto typedConstArr = mlir::cast<mlir::TypedAttr>(constArr); |
122 | |
123 | // Ensure ConstArrayAttr type is a ArrayType. |
124 | const auto cirArrayType = mlir::cast<cir::ArrayType>(typedConstArr.getType()); |
125 | |
126 | // Is a ConstArrayAttr with an cir::ArrayType: fetch element type. |
127 | mlir::Type type = cirArrayType; |
128 | auto dims = llvm::SmallVector<int64_t, 2>{}; |
129 | while (auto arrayType = mlir::dyn_cast<cir::ArrayType>(type)) { |
130 | dims.push_back(Elt: arrayType.getSize()); |
131 | type = arrayType.getElementType(); |
132 | } |
133 | |
134 | if (mlir::isa<mlir::StringAttr>(constArr.getElts())) |
135 | return convertStringAttrToDenseElementsAttr(constArr, |
136 | converter->convertType(type)); |
137 | if (mlir::isa<cir::IntType>(type)) |
138 | return convertToDenseElementsAttr<cir::IntAttr, mlir::APInt>( |
139 | constArr, dims, type, converter->convertType(type)); |
140 | |
141 | if (mlir::isa<cir::CIRFPTypeInterface>(type)) |
142 | return convertToDenseElementsAttr<cir::FPAttr, mlir::APFloat>( |
143 | constArr, dims, type, converter->convertType(type)); |
144 | |
145 | return std::nullopt; |
146 | } |
147 | |