| 1 | //===- FunctionSupport.cpp - Utility types for function-like ops ----------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Interfaces/FunctionInterfaces.h" |
| 10 | |
| 11 | using namespace mlir; |
| 12 | |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | // Tablegen Interface Definitions |
| 15 | //===----------------------------------------------------------------------===// |
| 16 | |
| 17 | #include "mlir/Interfaces/FunctionInterfaces.cpp.inc" |
| 18 | |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // Function Arguments and Results. |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | static bool isEmptyAttrDict(Attribute attr) { |
| 24 | return llvm::cast<DictionaryAttr>(attr).empty(); |
| 25 | } |
| 26 | |
| 27 | DictionaryAttr function_interface_impl::getArgAttrDict(FunctionOpInterface op, |
| 28 | unsigned index) { |
| 29 | ArrayAttr attrs = op.getArgAttrsAttr(); |
| 30 | DictionaryAttr argAttrs = |
| 31 | attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
| 32 | return argAttrs; |
| 33 | } |
| 34 | |
| 35 | DictionaryAttr |
| 36 | function_interface_impl::getResultAttrDict(FunctionOpInterface op, |
| 37 | unsigned index) { |
| 38 | ArrayAttr attrs = op.getResAttrsAttr(); |
| 39 | DictionaryAttr resAttrs = |
| 40 | attrs ? llvm::cast<DictionaryAttr>(attrs[index]) : DictionaryAttr(); |
| 41 | return resAttrs; |
| 42 | } |
| 43 | |
| 44 | ArrayRef<NamedAttribute> |
| 45 | function_interface_impl::getArgAttrs(FunctionOpInterface op, unsigned index) { |
| 46 | auto argDict = getArgAttrDict(op, index); |
| 47 | return argDict ? argDict.getValue() : std::nullopt; |
| 48 | } |
| 49 | |
| 50 | ArrayRef<NamedAttribute> |
| 51 | function_interface_impl::getResultAttrs(FunctionOpInterface op, |
| 52 | unsigned index) { |
| 53 | auto resultDict = getResultAttrDict(op, index); |
| 54 | return resultDict ? resultDict.getValue() : std::nullopt; |
| 55 | } |
| 56 | |
| 57 | /// Get either the argument or result attributes array. |
| 58 | template <bool isArg> |
| 59 | static ArrayAttr getArgResAttrs(FunctionOpInterface op) { |
| 60 | if constexpr (isArg) |
| 61 | return op.getArgAttrsAttr(); |
| 62 | else |
| 63 | return op.getResAttrsAttr(); |
| 64 | } |
| 65 | |
| 66 | /// Set either the argument or result attributes array. |
| 67 | template <bool isArg> |
| 68 | static void setArgResAttrs(FunctionOpInterface op, ArrayAttr attrs) { |
| 69 | if constexpr (isArg) |
| 70 | op.setArgAttrsAttr(attrs); |
| 71 | else |
| 72 | op.setResAttrsAttr(attrs); |
| 73 | } |
| 74 | |
| 75 | /// Erase either the argument or result attributes array. |
| 76 | template <bool isArg> |
| 77 | static void removeArgResAttrs(FunctionOpInterface op) { |
| 78 | if constexpr (isArg) |
| 79 | op.removeArgAttrsAttr(); |
| 80 | else |
| 81 | op.removeResAttrsAttr(); |
| 82 | } |
| 83 | |
| 84 | /// Set all of the argument or result attribute dictionaries for a function. |
| 85 | template <bool isArg> |
| 86 | static void setAllArgResAttrDicts(FunctionOpInterface op, |
| 87 | ArrayRef<Attribute> attrs) { |
| 88 | if (llvm::all_of(Range&: attrs, P: isEmptyAttrDict)) |
| 89 | removeArgResAttrs<isArg>(op); |
| 90 | else |
| 91 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), attrs)); |
| 92 | } |
| 93 | |
| 94 | void function_interface_impl::setAllArgAttrDicts( |
| 95 | FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
| 96 | setAllArgAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
| 97 | } |
| 98 | |
| 99 | void function_interface_impl::setAllArgAttrDicts(FunctionOpInterface op, |
| 100 | ArrayRef<Attribute> attrs) { |
| 101 | auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute { |
| 102 | return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
| 103 | }); |
| 104 | setAllArgResAttrDicts</*isArg=*/true>(op, llvm::to_vector<8>(Range&: wrappedAttrs)); |
| 105 | } |
| 106 | |
| 107 | void function_interface_impl::setAllResultAttrDicts( |
| 108 | FunctionOpInterface op, ArrayRef<DictionaryAttr> attrs) { |
| 109 | setAllResultAttrDicts(op, ArrayRef<Attribute>(attrs.data(), attrs.size())); |
| 110 | } |
| 111 | |
| 112 | void function_interface_impl::setAllResultAttrDicts(FunctionOpInterface op, |
| 113 | ArrayRef<Attribute> attrs) { |
| 114 | auto wrappedAttrs = llvm::map_range(C&: attrs, F: [op](Attribute attr) -> Attribute { |
| 115 | return !attr ? DictionaryAttr::get(op->getContext()) : attr; |
| 116 | }); |
| 117 | setAllArgResAttrDicts</*isArg=*/false>(op, llvm::to_vector<8>(Range&: wrappedAttrs)); |
| 118 | } |
| 119 | |
| 120 | /// Update the given index into an argument or result attribute dictionary. |
| 121 | template <bool isArg> |
| 122 | static void setArgResAttrDict(FunctionOpInterface op, unsigned numTotalIndices, |
| 123 | unsigned index, DictionaryAttr attrs) { |
| 124 | ArrayAttr allAttrs = getArgResAttrs<isArg>(op); |
| 125 | if (!allAttrs) { |
| 126 | if (attrs.empty()) |
| 127 | return; |
| 128 | |
| 129 | // If this attribute is not empty, we need to create a new attribute array. |
| 130 | SmallVector<Attribute, 8> newAttrs(numTotalIndices, |
| 131 | DictionaryAttr::get(op->getContext())); |
| 132 | newAttrs[index] = attrs; |
| 133 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
| 134 | return; |
| 135 | } |
| 136 | // Check to see if the attribute is different from what we already have. |
| 137 | if (allAttrs[index] == attrs) |
| 138 | return; |
| 139 | |
| 140 | // If it is, check to see if the attribute array would now contain only empty |
| 141 | // dictionaries. |
| 142 | ArrayRef<Attribute> rawAttrArray = allAttrs.getValue(); |
| 143 | if (attrs.empty() && |
| 144 | llvm::all_of(Range: rawAttrArray.take_front(N: index), P: isEmptyAttrDict) && |
| 145 | llvm::all_of(Range: rawAttrArray.drop_front(N: index + 1), P: isEmptyAttrDict)) |
| 146 | return removeArgResAttrs<isArg>(op); |
| 147 | |
| 148 | // Otherwise, create a new attribute array with the updated dictionary. |
| 149 | SmallVector<Attribute, 8> newAttrs(rawAttrArray); |
| 150 | newAttrs[index] = attrs; |
| 151 | setArgResAttrs<isArg>(op, ArrayAttr::get(op->getContext(), newAttrs)); |
| 152 | } |
| 153 | |
| 154 | void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
| 155 | unsigned index, |
| 156 | ArrayRef<NamedAttribute> attributes) { |
| 157 | assert(index < op.getNumArguments() && "invalid argument number" ); |
| 158 | return setArgResAttrDict</*isArg=*/true>( |
| 159 | op, op.getNumArguments(), index, |
| 160 | DictionaryAttr::get(op->getContext(), attributes)); |
| 161 | } |
| 162 | |
| 163 | void function_interface_impl::setArgAttrs(FunctionOpInterface op, |
| 164 | unsigned index, |
| 165 | DictionaryAttr attributes) { |
| 166 | return setArgResAttrDict</*isArg=*/true>( |
| 167 | op, op.getNumArguments(), index, |
| 168 | attributes ? attributes : DictionaryAttr::get(op->getContext())); |
| 169 | } |
| 170 | |
| 171 | void function_interface_impl::setResultAttrs( |
| 172 | FunctionOpInterface op, unsigned index, |
| 173 | ArrayRef<NamedAttribute> attributes) { |
| 174 | assert(index < op.getNumResults() && "invalid result number" ); |
| 175 | return setArgResAttrDict</*isArg=*/false>( |
| 176 | op, op.getNumResults(), index, |
| 177 | DictionaryAttr::get(op->getContext(), attributes)); |
| 178 | } |
| 179 | |
| 180 | void function_interface_impl::setResultAttrs(FunctionOpInterface op, |
| 181 | unsigned index, |
| 182 | DictionaryAttr attributes) { |
| 183 | assert(index < op.getNumResults() && "invalid result number" ); |
| 184 | return setArgResAttrDict</*isArg=*/false>( |
| 185 | op, op.getNumResults(), index, |
| 186 | attributes ? attributes : DictionaryAttr::get(op->getContext())); |
| 187 | } |
| 188 | |
| 189 | void function_interface_impl::insertFunctionArguments( |
| 190 | FunctionOpInterface op, ArrayRef<unsigned> argIndices, TypeRange argTypes, |
| 191 | ArrayRef<DictionaryAttr> argAttrs, ArrayRef<Location> argLocs, |
| 192 | unsigned originalNumArgs, Type newType) { |
| 193 | assert(argIndices.size() == argTypes.size()); |
| 194 | assert(argIndices.size() == argAttrs.size() || argAttrs.empty()); |
| 195 | assert(argIndices.size() == argLocs.size()); |
| 196 | if (argIndices.empty()) |
| 197 | return; |
| 198 | |
| 199 | // There are 3 things that need to be updated: |
| 200 | // - Function type. |
| 201 | // - Arg attrs. |
| 202 | // - Block arguments of entry block, if not empty. |
| 203 | |
| 204 | // Update the argument attributes of the function. |
| 205 | ArrayAttr oldArgAttrs = op.getArgAttrsAttr(); |
| 206 | if (oldArgAttrs || !argAttrs.empty()) { |
| 207 | SmallVector<DictionaryAttr, 4> newArgAttrs; |
| 208 | newArgAttrs.reserve(originalNumArgs + argIndices.size()); |
| 209 | unsigned oldIdx = 0; |
| 210 | auto migrate = [&](unsigned untilIdx) { |
| 211 | if (!oldArgAttrs) { |
| 212 | newArgAttrs.resize(newArgAttrs.size() + untilIdx - oldIdx); |
| 213 | } else { |
| 214 | auto oldArgAttrRange = oldArgAttrs.getAsRange<DictionaryAttr>(); |
| 215 | newArgAttrs.append(oldArgAttrRange.begin() + oldIdx, |
| 216 | oldArgAttrRange.begin() + untilIdx); |
| 217 | } |
| 218 | oldIdx = untilIdx; |
| 219 | }; |
| 220 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) { |
| 221 | migrate(argIndices[i]); |
| 222 | newArgAttrs.push_back(argAttrs.empty() ? DictionaryAttr{} : argAttrs[i]); |
| 223 | } |
| 224 | migrate(originalNumArgs); |
| 225 | setAllArgAttrDicts(op, newArgAttrs); |
| 226 | } |
| 227 | |
| 228 | // Update the function type. |
| 229 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| 230 | |
| 231 | // Update entry block arguments, if not empty. |
| 232 | if (!op.isExternal()) { |
| 233 | Block &entry = op->getRegion(0).front(); |
| 234 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) |
| 235 | entry.insertArgument(index: argIndices[i] + i, type: argTypes[i], loc: argLocs[i]); |
| 236 | } |
| 237 | } |
| 238 | |
| 239 | void function_interface_impl::insertFunctionResults( |
| 240 | FunctionOpInterface op, ArrayRef<unsigned> resultIndices, |
| 241 | TypeRange resultTypes, ArrayRef<DictionaryAttr> resultAttrs, |
| 242 | unsigned originalNumResults, Type newType) { |
| 243 | assert(resultIndices.size() == resultTypes.size()); |
| 244 | assert(resultIndices.size() == resultAttrs.size() || resultAttrs.empty()); |
| 245 | if (resultIndices.empty()) |
| 246 | return; |
| 247 | |
| 248 | // There are 2 things that need to be updated: |
| 249 | // - Function type. |
| 250 | // - Result attrs. |
| 251 | |
| 252 | // Update the result attributes of the function. |
| 253 | ArrayAttr oldResultAttrs = op.getResAttrsAttr(); |
| 254 | if (oldResultAttrs || !resultAttrs.empty()) { |
| 255 | SmallVector<DictionaryAttr, 4> newResultAttrs; |
| 256 | newResultAttrs.reserve(originalNumResults + resultIndices.size()); |
| 257 | unsigned oldIdx = 0; |
| 258 | auto migrate = [&](unsigned untilIdx) { |
| 259 | if (!oldResultAttrs) { |
| 260 | newResultAttrs.resize(newResultAttrs.size() + untilIdx - oldIdx); |
| 261 | } else { |
| 262 | auto oldResultAttrsRange = oldResultAttrs.getAsRange<DictionaryAttr>(); |
| 263 | newResultAttrs.append(oldResultAttrsRange.begin() + oldIdx, |
| 264 | oldResultAttrsRange.begin() + untilIdx); |
| 265 | } |
| 266 | oldIdx = untilIdx; |
| 267 | }; |
| 268 | for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) { |
| 269 | migrate(resultIndices[i]); |
| 270 | newResultAttrs.push_back(resultAttrs.empty() ? DictionaryAttr{} |
| 271 | : resultAttrs[i]); |
| 272 | } |
| 273 | migrate(originalNumResults); |
| 274 | setAllResultAttrDicts(op, newResultAttrs); |
| 275 | } |
| 276 | |
| 277 | // Update the function type. |
| 278 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| 279 | } |
| 280 | |
| 281 | void function_interface_impl::eraseFunctionArguments( |
| 282 | FunctionOpInterface op, const BitVector &argIndices, Type newType) { |
| 283 | // There are 3 things that need to be updated: |
| 284 | // - Function type. |
| 285 | // - Arg attrs. |
| 286 | // - Block arguments of entry block, if not empty. |
| 287 | |
| 288 | // Update the argument attributes of the function. |
| 289 | if (ArrayAttr argAttrs = op.getArgAttrsAttr()) { |
| 290 | SmallVector<DictionaryAttr, 4> newArgAttrs; |
| 291 | newArgAttrs.reserve(argAttrs.size()); |
| 292 | for (unsigned i = 0, e = argIndices.size(); i < e; ++i) |
| 293 | if (!argIndices[i]) |
| 294 | newArgAttrs.emplace_back(llvm::cast<DictionaryAttr>(argAttrs[i])); |
| 295 | setAllArgAttrDicts(op, newArgAttrs); |
| 296 | } |
| 297 | |
| 298 | // Update the function type. |
| 299 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| 300 | |
| 301 | // Update entry block arguments, if not empty. |
| 302 | if (!op.isExternal()) { |
| 303 | Block &entry = op->getRegion(0).front(); |
| 304 | entry.eraseArguments(eraseIndices: argIndices); |
| 305 | } |
| 306 | } |
| 307 | |
| 308 | void function_interface_impl::eraseFunctionResults( |
| 309 | FunctionOpInterface op, const BitVector &resultIndices, Type newType) { |
| 310 | // There are 2 things that need to be updated: |
| 311 | // - Function type. |
| 312 | // - Result attrs. |
| 313 | |
| 314 | // Update the result attributes of the function. |
| 315 | if (ArrayAttr resAttrs = op.getResAttrsAttr()) { |
| 316 | SmallVector<DictionaryAttr, 4> newResultAttrs; |
| 317 | newResultAttrs.reserve(resAttrs.size()); |
| 318 | for (unsigned i = 0, e = resultIndices.size(); i < e; ++i) |
| 319 | if (!resultIndices[i]) |
| 320 | newResultAttrs.emplace_back(llvm::cast<DictionaryAttr>(resAttrs[i])); |
| 321 | setAllResultAttrDicts(op, newResultAttrs); |
| 322 | } |
| 323 | |
| 324 | // Update the function type. |
| 325 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| 326 | } |
| 327 | |
| 328 | //===----------------------------------------------------------------------===// |
| 329 | // Function type signature. |
| 330 | //===----------------------------------------------------------------------===// |
| 331 | |
| 332 | void function_interface_impl::setFunctionType(FunctionOpInterface op, |
| 333 | Type newType) { |
| 334 | unsigned oldNumArgs = op.getNumArguments(); |
| 335 | unsigned oldNumResults = op.getNumResults(); |
| 336 | op.setFunctionTypeAttr(TypeAttr::get(newType)); |
| 337 | unsigned newNumArgs = op.getNumArguments(); |
| 338 | unsigned newNumResults = op.getNumResults(); |
| 339 | |
| 340 | // Functor used to update the argument and result attributes of the function. |
| 341 | auto emptyDict = DictionaryAttr::get(op.getContext()); |
| 342 | auto updateAttrFn = [&](auto isArg, unsigned oldCount, unsigned newCount) { |
| 343 | constexpr bool isArgVal = std::is_same_v<decltype(isArg), std::true_type>; |
| 344 | |
| 345 | if (oldCount == newCount) |
| 346 | return; |
| 347 | // The new type has no arguments/results, just drop the attribute. |
| 348 | if (newCount == 0) |
| 349 | return removeArgResAttrs<isArgVal>(op); |
| 350 | ArrayAttr attrs = getArgResAttrs<isArgVal>(op); |
| 351 | if (!attrs) |
| 352 | return; |
| 353 | |
| 354 | // The new type has less arguments/results, take the first N attributes. |
| 355 | if (newCount < oldCount) |
| 356 | return setAllArgResAttrDicts<isArgVal>( |
| 357 | op, attrs.getValue().take_front(newCount)); |
| 358 | |
| 359 | // Otherwise, the new type has more arguments/results. Initialize the new |
| 360 | // arguments/results with empty dictionary attributes. |
| 361 | SmallVector<Attribute> newAttrs(attrs.begin(), attrs.end()); |
| 362 | newAttrs.resize(newCount, emptyDict); |
| 363 | setAllArgResAttrDicts<isArgVal>(op, newAttrs); |
| 364 | }; |
| 365 | |
| 366 | // Update the argument and result attributes. |
| 367 | updateAttrFn(std::true_type{}, oldNumArgs, newNumArgs); |
| 368 | updateAttrFn(std::false_type{}, oldNumResults, newNumResults); |
| 369 | } |
| 370 | |