| 1 | //===- BytecodeImplementation.h - MLIR Bytecode Implementation --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This header defines various interfaces and utilities necessary for dialects |
| 10 | // to hook into bytecode serialization. |
| 11 | // |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #ifndef MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
| 15 | #define MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
| 16 | |
| 17 | #include "mlir/IR/Attributes.h" |
| 18 | #include "mlir/IR/Diagnostics.h" |
| 19 | #include "mlir/IR/Dialect.h" |
| 20 | #include "mlir/IR/DialectInterface.h" |
| 21 | #include "mlir/IR/OpImplementation.h" |
| 22 | #include "llvm/ADT/STLExtras.h" |
| 23 | #include "llvm/ADT/Twine.h" |
| 24 | |
| 25 | namespace mlir { |
| 26 | //===--------------------------------------------------------------------===// |
| 27 | // Dialect Version Interface. |
| 28 | //===--------------------------------------------------------------------===// |
| 29 | |
/// Polymorphic base class for dialect version objects. A dialect derives from
/// this class to hold its own version information; the virtual destructor
/// allows such objects to be owned and destroyed through a base pointer, e.g.
/// the `std::unique_ptr<DialectVersion>` returned by
/// `BytecodeDialectInterface::readVersion`.
class DialectVersion {
public:
  /// Virtual so that derived version objects are destroyed correctly when
  /// deleted through a `DialectVersion *`.
  virtual ~DialectVersion() = default;
};
| 36 | |
| 37 | //===----------------------------------------------------------------------===// |
| 38 | // DialectBytecodeReader |
| 39 | //===----------------------------------------------------------------------===// |
| 40 | |
| 41 | /// This class defines a virtual interface for reading a bytecode stream, |
| 42 | /// providing hooks into the bytecode reader. As such, this class should only be |
| 43 | /// derived and defined by the main bytecode reader, users (i.e. dialects) |
| 44 | /// should generally only interact with this class via the |
| 45 | /// BytecodeDialectInterface below. |
| 46 | class DialectBytecodeReader { |
| 47 | public: |
| 48 | virtual ~DialectBytecodeReader() = default; |
| 49 | |
| 50 | /// Emit an error to the reader. |
| 51 | virtual InFlightDiagnostic emitError(const Twine &msg = {}) const = 0; |
| 52 | |
| 53 | /// Retrieve the dialect version by name if available. |
| 54 | virtual FailureOr<const DialectVersion *> |
| 55 | getDialectVersion(StringRef dialectName) const = 0; |
| 56 | template <class T> |
| 57 | FailureOr<const DialectVersion *> getDialectVersion() const { |
| 58 | return getDialectVersion(T::getDialectNamespace()); |
| 59 | } |
| 60 | |
| 61 | /// Retrieve the context associated to the reader. |
| 62 | virtual MLIRContext *getContext() const = 0; |
| 63 | |
| 64 | /// Return the bytecode version being read. |
| 65 | virtual uint64_t getBytecodeVersion() const = 0; |
| 66 | |
| 67 | /// Read out a list of elements, invoking the provided callback for each |
| 68 | /// element. The callback function may be in any of the following forms: |
| 69 | /// * LogicalResult(T &) |
| 70 | /// * FailureOr<T>() |
| 71 | template <typename T, typename CallbackFn> |
| 72 | LogicalResult readList(SmallVectorImpl<T> &result, CallbackFn &&callback) { |
| 73 | uint64_t size; |
| 74 | if (failed(Result: readVarInt(result&: size))) |
| 75 | return failure(); |
| 76 | result.reserve(size); |
| 77 | |
| 78 | for (uint64_t i = 0; i < size; ++i) { |
| 79 | // Check if the callback uses FailureOr, or populates the result by |
| 80 | // reference. |
| 81 | if constexpr (llvm::function_traits<std::decay_t<CallbackFn>>::num_args) { |
| 82 | T element = {}; |
| 83 | if (failed(callback(element))) |
| 84 | return failure(); |
| 85 | result.emplace_back(std::move(element)); |
| 86 | } else { |
| 87 | FailureOr<T> element = callback(); |
| 88 | if (failed(element)) |
| 89 | return failure(); |
| 90 | result.emplace_back(std::move(*element)); |
| 91 | } |
| 92 | } |
| 93 | return success(); |
| 94 | } |
| 95 | |
| 96 | //===--------------------------------------------------------------------===// |
| 97 | // IR |
| 98 | //===--------------------------------------------------------------------===// |
| 99 | |
| 100 | /// Read a reference to the given attribute. |
| 101 | virtual LogicalResult readAttribute(Attribute &result) = 0; |
| 102 | /// Read an optional reference to the given attribute. Returns success even if |
| 103 | /// the Attribute isn't present. |
| 104 | virtual LogicalResult readOptionalAttribute(Attribute &attr) = 0; |
| 105 | |
| 106 | template <typename T> |
| 107 | LogicalResult readAttributes(SmallVectorImpl<T> &attrs) { |
| 108 | return readList(attrs, [this](T &attr) { return readAttribute(attr); }); |
| 109 | } |
| 110 | template <typename T> |
| 111 | LogicalResult readAttribute(T &result) { |
| 112 | Attribute baseResult; |
| 113 | if (failed(Result: readAttribute(result&: baseResult))) |
| 114 | return failure(); |
| 115 | if ((result = dyn_cast<T>(baseResult))) |
| 116 | return success(); |
| 117 | return emitError() << "expected " << llvm::getTypeName<T>() |
| 118 | << ", but got: " << baseResult; |
| 119 | } |
| 120 | template <typename T> |
| 121 | LogicalResult readOptionalAttribute(T &result) { |
| 122 | Attribute baseResult; |
| 123 | if (failed(Result: readOptionalAttribute(attr&: baseResult))) |
| 124 | return failure(); |
| 125 | if (!baseResult) |
| 126 | return success(); |
| 127 | if ((result = dyn_cast<T>(baseResult))) |
| 128 | return success(); |
| 129 | return emitError() << "expected " << llvm::getTypeName<T>() |
| 130 | << ", but got: " << baseResult; |
| 131 | } |
| 132 | |
| 133 | /// Read a reference to the given type. |
| 134 | virtual LogicalResult readType(Type &result) = 0; |
| 135 | template <typename T> |
| 136 | LogicalResult readTypes(SmallVectorImpl<T> &types) { |
| 137 | return readList(types, [this](T &type) { return readType(type); }); |
| 138 | } |
| 139 | template <typename T> |
| 140 | LogicalResult readType(T &result) { |
| 141 | Type baseResult; |
| 142 | if (failed(Result: readType(result&: baseResult))) |
| 143 | return failure(); |
| 144 | if ((result = dyn_cast<T>(baseResult))) |
| 145 | return success(); |
| 146 | return emitError() << "expected " << llvm::getTypeName<T>() |
| 147 | << ", but got: " << baseResult; |
| 148 | } |
| 149 | |
| 150 | /// Read a handle to a dialect resource. |
| 151 | template <typename ResourceT> |
| 152 | FailureOr<ResourceT> readResourceHandle() { |
| 153 | FailureOr<AsmDialectResourceHandle> handle = readResourceHandle(); |
| 154 | if (failed(Result: handle)) |
| 155 | return failure(); |
| 156 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
| 157 | return std::move(*result); |
| 158 | return emitError() << "provided resource handle differs from the " |
| 159 | "expected resource type" ; |
| 160 | } |
| 161 | |
| 162 | //===--------------------------------------------------------------------===// |
| 163 | // Primitives |
| 164 | //===--------------------------------------------------------------------===// |
| 165 | |
| 166 | /// Read a variable width integer. |
| 167 | virtual LogicalResult readVarInt(uint64_t &result) = 0; |
| 168 | |
| 169 | /// Read a signed variable width integer. |
| 170 | virtual LogicalResult readSignedVarInt(int64_t &result) = 0; |
| 171 | LogicalResult readSignedVarInts(SmallVectorImpl<int64_t> &result) { |
| 172 | return readList(result, |
| 173 | callback: [this](int64_t &value) { return readSignedVarInt(result&: value); }); |
| 174 | } |
| 175 | |
| 176 | /// Parse a variable length encoded integer whose low bit is used to encode an |
| 177 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. |
| 178 | LogicalResult readVarIntWithFlag(uint64_t &result, bool &flag) { |
| 179 | if (failed(Result: readVarInt(result))) |
| 180 | return failure(); |
| 181 | flag = result & 1; |
| 182 | result >>= 1; |
| 183 | return success(); |
| 184 | } |
| 185 | |
| 186 | /// Read a "small" sparse array of integer <= 32 bits elements, where |
| 187 | /// index/value pairs can be compressed when the array is small. |
| 188 | /// Note that only some position of the array will be read and the ones |
| 189 | /// not stored in the bytecode are gonne be left untouched. |
| 190 | /// If the provided array is too small for the stored indices, an error |
| 191 | /// will be returned. |
| 192 | template <typename T> |
| 193 | LogicalResult readSparseArray(MutableArrayRef<T> array) { |
| 194 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits" ); |
| 195 | static_assert(std::is_integral<T>::value, "expects integer" ); |
| 196 | uint64_t nonZeroesCount; |
| 197 | bool useSparseEncoding; |
| 198 | if (failed(Result: readVarIntWithFlag(result&: nonZeroesCount, flag&: useSparseEncoding))) |
| 199 | return failure(); |
| 200 | if (nonZeroesCount == 0) |
| 201 | return success(); |
| 202 | if (!useSparseEncoding) { |
| 203 | // This is a simple dense array. |
| 204 | if (nonZeroesCount > array.size()) { |
| 205 | emitError(msg: "trying to read an array of " ) |
| 206 | << nonZeroesCount << " but only " << array.size() |
| 207 | << " storage available." ; |
| 208 | return failure(); |
| 209 | } |
| 210 | for (int64_t index : llvm::seq<int64_t>(Begin: 0, End: nonZeroesCount)) { |
| 211 | uint64_t value; |
| 212 | if (failed(Result: readVarInt(result&: value))) |
| 213 | return failure(); |
| 214 | array[index] = value; |
| 215 | } |
| 216 | return success(); |
| 217 | } |
| 218 | // Read sparse encoding |
| 219 | // This is the number of bits used for packing the index with the value. |
| 220 | uint64_t indexBitSize; |
| 221 | if (failed(Result: readVarInt(result&: indexBitSize))) |
| 222 | return failure(); |
| 223 | constexpr uint64_t maxIndexBitSize = 8; |
| 224 | if (indexBitSize > maxIndexBitSize) { |
| 225 | emitError(msg: "reading sparse array with indexing above 8 bits: " ) |
| 226 | << indexBitSize; |
| 227 | return failure(); |
| 228 | } |
| 229 | for (uint32_t count : llvm::seq<uint32_t>(Begin: 0, End: nonZeroesCount)) { |
| 230 | (void)count; |
| 231 | uint64_t indexValuePair; |
| 232 | if (failed(Result: readVarInt(result&: indexValuePair))) |
| 233 | return failure(); |
| 234 | uint64_t index = indexValuePair & ~(uint64_t(-1) << (indexBitSize)); |
| 235 | uint64_t value = indexValuePair >> indexBitSize; |
| 236 | if (index >= array.size()) { |
| 237 | emitError(msg: "reading a sparse array found index " ) |
| 238 | << index << " but only " << array.size() << " storage available." ; |
| 239 | return failure(); |
| 240 | } |
| 241 | array[index] = value; |
| 242 | } |
| 243 | return success(); |
| 244 | } |
| 245 | |
| 246 | /// Read an APInt that is known to have been encoded with the given width. |
| 247 | virtual FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) = 0; |
| 248 | |
| 249 | /// Read an APFloat that is known to have been encoded with the given |
| 250 | /// semantics. |
| 251 | virtual FailureOr<APFloat> |
| 252 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) = 0; |
| 253 | |
| 254 | /// Read a string from the bytecode. |
| 255 | virtual LogicalResult readString(StringRef &result) = 0; |
| 256 | |
| 257 | /// Read a blob from the bytecode. |
| 258 | virtual LogicalResult readBlob(ArrayRef<char> &result) = 0; |
| 259 | |
| 260 | /// Read a bool from the bytecode. |
| 261 | virtual LogicalResult readBool(bool &result) = 0; |
| 262 | |
| 263 | private: |
| 264 | /// Read a handle to a dialect resource. |
| 265 | virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0; |
| 266 | }; |
| 267 | |
| 268 | //===----------------------------------------------------------------------===// |
| 269 | // DialectBytecodeWriter |
| 270 | //===----------------------------------------------------------------------===// |
| 271 | |
| 272 | /// This class defines a virtual interface for writing to a bytecode stream, |
| 273 | /// providing hooks into the bytecode writer. As such, this class should only be |
| 274 | /// derived and defined by the main bytecode writer, users (i.e. dialects) |
| 275 | /// should generally only interact with this class via the |
| 276 | /// BytecodeDialectInterface below. |
| 277 | class DialectBytecodeWriter { |
| 278 | public: |
| 279 | virtual ~DialectBytecodeWriter() = default; |
| 280 | |
| 281 | //===--------------------------------------------------------------------===// |
| 282 | // IR |
| 283 | //===--------------------------------------------------------------------===// |
| 284 | |
| 285 | /// Write out a list of elements, invoking the provided callback for each |
| 286 | /// element. |
| 287 | template <typename RangeT, typename CallbackFn> |
| 288 | void writeList(RangeT &&range, CallbackFn &&callback) { |
| 289 | writeVarInt(value: llvm::size(range)); |
| 290 | for (auto &element : range) |
| 291 | callback(element); |
| 292 | } |
| 293 | |
| 294 | /// Write a reference to the given attribute. |
| 295 | virtual void writeAttribute(Attribute attr) = 0; |
| 296 | virtual void writeOptionalAttribute(Attribute attr) = 0; |
| 297 | template <typename T> |
| 298 | void writeAttributes(ArrayRef<T> attrs) { |
| 299 | writeList(attrs, [this](T attr) { writeAttribute(attr); }); |
| 300 | } |
| 301 | |
| 302 | /// Write a reference to the given type. |
| 303 | virtual void writeType(Type type) = 0; |
| 304 | template <typename T> |
| 305 | void writeTypes(ArrayRef<T> types) { |
| 306 | writeList(types, [this](T type) { writeType(type); }); |
| 307 | } |
| 308 | |
| 309 | /// Write the given handle to a dialect resource. |
| 310 | virtual void |
| 311 | writeResourceHandle(const AsmDialectResourceHandle &resource) = 0; |
| 312 | |
| 313 | //===--------------------------------------------------------------------===// |
| 314 | // Primitives |
| 315 | //===--------------------------------------------------------------------===// |
| 316 | |
| 317 | /// Write a variable width integer to the output stream. This should be the |
| 318 | /// preferred method for emitting integers whenever possible. |
| 319 | virtual void writeVarInt(uint64_t value) = 0; |
| 320 | |
| 321 | /// Write a signed variable width integer to the output stream. This should be |
| 322 | /// the preferred method for emitting signed integers whenever possible. |
| 323 | virtual void writeSignedVarInt(int64_t value) = 0; |
| 324 | void writeSignedVarInts(ArrayRef<int64_t> value) { |
| 325 | writeList(range&: value, callback: [this](int64_t value) { writeSignedVarInt(value); }); |
| 326 | } |
| 327 | |
| 328 | /// Write a VarInt and a flag packed together. |
| 329 | void writeVarIntWithFlag(uint64_t value, bool flag) { |
| 330 | writeVarInt(value: (value << 1) | (flag ? 1 : 0)); |
| 331 | } |
| 332 | |
| 333 | /// Write out a "small" sparse array of integer <= 32 bits elements, where |
| 334 | /// index/value pairs can be compressed when the array is small. This method |
| 335 | /// will scan the array multiple times and should not be used for large |
| 336 | /// arrays. The optional provided "zero" can be used to adjust for the |
| 337 | /// expected repeated value. We assume here that the array size fits in a 32 |
| 338 | /// bits integer. |
| 339 | template <typename T> |
| 340 | void writeSparseArray(ArrayRef<T> array) { |
| 341 | static_assert(sizeof(T) < sizeof(uint64_t), "expect integer < 64 bits" ); |
| 342 | static_assert(std::is_integral<T>::value, "expects integer" ); |
| 343 | uint32_t size = array.size(); |
| 344 | uint32_t nonZeroesCount = 0, lastIndex = 0; |
| 345 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: size)) { |
| 346 | if (!array[index]) |
| 347 | continue; |
| 348 | nonZeroesCount++; |
| 349 | lastIndex = index; |
| 350 | } |
| 351 | // If the last position is too large, or the array isn't at least 50% |
| 352 | // sparse, emit it with a dense encoding. |
| 353 | if (lastIndex > 256 || nonZeroesCount > size / 2) { |
| 354 | // Emit the array size and a flag which indicates whether it is sparse. |
| 355 | writeVarIntWithFlag(value: size, flag: false); |
| 356 | for (const T &elt : array) |
| 357 | writeVarInt(value: elt); |
| 358 | return; |
| 359 | } |
| 360 | // Emit sparse: first the number of elements we'll write and a flag |
| 361 | // indicating it is a sparse encoding. |
| 362 | writeVarIntWithFlag(value: nonZeroesCount, flag: true); |
| 363 | if (nonZeroesCount == 0) |
| 364 | return; |
| 365 | // This is the number of bits used for packing the index with the value. |
| 366 | int indexBitSize = llvm::Log2_32_Ceil(Value: lastIndex + 1); |
| 367 | writeVarInt(value: indexBitSize); |
| 368 | for (uint32_t index : llvm::seq<uint32_t>(Begin: 0, End: lastIndex + 1)) { |
| 369 | T value = array[index]; |
| 370 | if (!value) |
| 371 | continue; |
| 372 | uint64_t indexValuePair = (value << indexBitSize) | (index); |
| 373 | writeVarInt(value: indexValuePair); |
| 374 | } |
| 375 | } |
| 376 | |
| 377 | /// Write an APInt to the bytecode stream whose bitwidth will be known |
| 378 | /// externally at read time. This method is useful for encoding APInt values |
| 379 | /// when the width is known via external means, such as via a type. This |
| 380 | /// method should generally only be invoked if you need an APInt, otherwise |
| 381 | /// use the varint methods above. APInt values are generally encoded using |
| 382 | /// zigzag encoding, to enable more efficient encodings for negative values. |
| 383 | virtual void writeAPIntWithKnownWidth(const APInt &value) = 0; |
| 384 | |
| 385 | /// Write an APFloat to the bytecode stream whose semantics will be known |
| 386 | /// externally at read time. This method is useful for encoding APFloat values |
| 387 | /// when the semantics are known via external means, such as via a type. |
| 388 | virtual void writeAPFloatWithKnownSemantics(const APFloat &value) = 0; |
| 389 | |
| 390 | /// Write a string to the bytecode, which is owned by the caller and is |
| 391 | /// guaranteed to not die before the end of the bytecode process. This should |
| 392 | /// only be called if such a guarantee can be made, such as when the string is |
| 393 | /// owned by an attribute or type. |
| 394 | virtual void writeOwnedString(StringRef str) = 0; |
| 395 | |
| 396 | /// Write a blob to the bytecode, which is owned by the caller and is |
| 397 | /// guaranteed to not die before the end of the bytecode process. The blob is |
| 398 | /// written as-is, with no additional compression or compaction. |
| 399 | virtual void writeOwnedBlob(ArrayRef<char> blob) = 0; |
| 400 | |
| 401 | /// Write a bool to the output stream. |
| 402 | virtual void writeOwnedBool(bool value) = 0; |
| 403 | |
| 404 | /// Return the bytecode version being emitted for. |
| 405 | virtual int64_t getBytecodeVersion() const = 0; |
| 406 | |
| 407 | /// Retrieve the dialect version by name if available. |
| 408 | virtual FailureOr<const DialectVersion *> |
| 409 | getDialectVersion(StringRef dialectName) const = 0; |
| 410 | |
| 411 | template <class T> |
| 412 | FailureOr<const DialectVersion *> getDialectVersion() const { |
| 413 | return getDialectVersion(T::getDialectNamespace()); |
| 414 | } |
| 415 | }; |
| 416 | |
| 417 | //===----------------------------------------------------------------------===// |
| 418 | // BytecodeDialectInterface |
| 419 | //===----------------------------------------------------------------------===// |
| 420 | |
| 421 | class BytecodeDialectInterface |
| 422 | : public DialectInterface::Base<BytecodeDialectInterface> { |
| 423 | public: |
| 424 | using Base::Base; |
| 425 | |
| 426 | //===--------------------------------------------------------------------===// |
| 427 | // Reading |
| 428 | //===--------------------------------------------------------------------===// |
| 429 | |
| 430 | /// Read an attribute belonging to this dialect from the given reader. This |
| 431 | /// method should return null in the case of failure. Optionally, the dialect |
| 432 | /// version can be accessed through the reader. |
| 433 | virtual Attribute readAttribute(DialectBytecodeReader &reader) const { |
| 434 | reader.emitError() << "dialect " << getDialect()->getNamespace() |
| 435 | << " does not support reading attributes from bytecode" ; |
| 436 | return Attribute(); |
| 437 | } |
| 438 | |
| 439 | /// Read a type belonging to this dialect from the given reader. This method |
| 440 | /// should return null in the case of failure. Optionally, the dialect version |
| 441 | /// can be accessed thorugh the reader. |
| 442 | virtual Type readType(DialectBytecodeReader &reader) const { |
| 443 | reader.emitError() << "dialect " << getDialect()->getNamespace() |
| 444 | << " does not support reading types from bytecode" ; |
| 445 | return Type(); |
| 446 | } |
| 447 | |
| 448 | //===--------------------------------------------------------------------===// |
| 449 | // Writing |
| 450 | //===--------------------------------------------------------------------===// |
| 451 | |
| 452 | /// Write the given attribute, which belongs to this dialect, to the given |
| 453 | /// writer. This method may return failure to indicate that the given |
| 454 | /// attribute could not be encoded, in which case the textual format will be |
| 455 | /// used to encode this attribute instead. |
| 456 | virtual LogicalResult writeAttribute(Attribute attr, |
| 457 | DialectBytecodeWriter &writer) const { |
| 458 | return failure(); |
| 459 | } |
| 460 | |
| 461 | /// Write the given type, which belongs to this dialect, to the given writer. |
| 462 | /// This method may return failure to indicate that the given type could not |
| 463 | /// be encoded, in which case the textual format will be used to encode this |
| 464 | /// type instead. |
| 465 | virtual LogicalResult writeType(Type type, |
| 466 | DialectBytecodeWriter &writer) const { |
| 467 | return failure(); |
| 468 | } |
| 469 | |
| 470 | /// Write the version of this dialect to the given writer. |
| 471 | virtual void writeVersion(DialectBytecodeWriter &writer) const {} |
| 472 | |
| 473 | // Read the version of this dialect from the provided reader and return it as |
| 474 | // a `unique_ptr` to a dialect version object. |
| 475 | virtual std::unique_ptr<DialectVersion> |
| 476 | readVersion(DialectBytecodeReader &reader) const { |
| 477 | reader.emitError(msg: "Dialect does not support versioning" ); |
| 478 | return nullptr; |
| 479 | } |
| 480 | |
| 481 | /// Hook invoked after parsing completed, if a version directive was present |
| 482 | /// and included an entry for the current dialect. This hook offers the |
| 483 | /// opportunity to the dialect to visit the IR and upgrades constructs emitted |
| 484 | /// by the version of the dialect corresponding to the provided version. |
| 485 | virtual LogicalResult |
| 486 | upgradeFromVersion(Operation *topLevelOp, |
| 487 | const DialectVersion &version) const { |
| 488 | return success(); |
| 489 | } |
| 490 | }; |
| 491 | |
| 492 | /// Helper for resource handle reading that returns LogicalResult. |
| 493 | template <typename T, typename... Ts> |
| 494 | static LogicalResult readResourceHandle(DialectBytecodeReader &reader, |
| 495 | FailureOr<T> &value, Ts &&...params) { |
| 496 | FailureOr<T> handle = reader.readResourceHandle<T>(); |
| 497 | if (failed(handle)) |
| 498 | return failure(); |
| 499 | if (auto *result = dyn_cast<T>(&*handle)) { |
| 500 | value = std::move(*result); |
| 501 | return success(); |
| 502 | } |
| 503 | return failure(); |
| 504 | } |
| 505 | |
| 506 | /// Helper method that injects context only if needed, this helps unify some of |
| 507 | /// the attribute construction methods. |
| 508 | template <typename T, typename... Ts> |
| 509 | auto get(MLIRContext *context, Ts &&...params) { |
| 510 | // Prefer a direct `get` method if one exists. |
| 511 | if constexpr (llvm::is_detected<detail::has_get_method, T, Ts...>::value) { |
| 512 | (void)context; |
| 513 | return T::get(std::forward<Ts>(params)...); |
| 514 | } else if constexpr (llvm::is_detected<detail::has_get_method, T, |
| 515 | MLIRContext *, Ts...>::value) { |
| 516 | return T::get(context, std::forward<Ts>(params)...); |
| 517 | } else { |
| 518 | // Otherwise, pass to the base get. |
| 519 | return T::Base::get(context, std::forward<Ts>(params)...); |
| 520 | } |
| 521 | } |
| 522 | |
| 523 | } // namespace mlir |
| 524 | |
| 525 | #endif // MLIR_BYTECODE_BYTECODEIMPLEMENTATION_H |
| 526 | |