| 1 | //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SMT/IR/SMTAttributes.h" |
| 10 | #include "mlir/Dialect/SMT/IR/SMTDialect.h" |
| 11 | #include "mlir/Dialect/SMT/IR/SMTTypes.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/DialectImplementation.h" |
| 14 | #include "llvm/ADT/TypeSwitch.h" |
| 15 | #include "llvm/Support/Format.h" |
| 16 | |
| 17 | using namespace mlir; |
| 18 | using namespace mlir::smt; |
| 19 | |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | // BitVectorAttr |
| 22 | //===----------------------------------------------------------------------===// |
| 23 | |
| 24 | LogicalResult BitVectorAttr::verify( |
| 25 | function_ref<InFlightDiagnostic()> emitError, |
| 26 | APInt value) { // NOLINT(performance-unnecessary-value-param) |
| 27 | if (value.getBitWidth() < 1) |
| 28 | return emitError() << "bit-width must be at least 1, but got " |
| 29 | << value.getBitWidth(); |
| 30 | return success(); |
| 31 | } |
| 32 | |
| 33 | std::string BitVectorAttr::getValueAsString(bool prefix) const { |
| 34 | unsigned width = getValue().getBitWidth(); |
| 35 | SmallVector<char> toPrint; |
| 36 | StringRef pref = prefix ? "#" : "" ; |
| 37 | if (width % 4 == 0) { |
| 38 | getValue().toString(toPrint, 16, false, false, false); |
| 39 | // APInt's 'toString' omits leading zeros. However, those are critical here |
| 40 | // because they determine the bit-width of the bit-vector. |
| 41 | SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0'); |
| 42 | return (pref + "x" + Twine(leadingZeros) + toPrint).str(); |
| 43 | } |
| 44 | |
| 45 | getValue().toString(toPrint, 2, false, false, false); |
| 46 | // APInt's 'toString' omits leading zeros |
| 47 | SmallVector<char> leadingZeros(width - toPrint.size(), '0'); |
| 48 | return (pref + "b" + Twine(leadingZeros) + toPrint).str(); |
| 49 | } |
| 50 | |
| 51 | /// Parse an SMT-LIB formatted bit-vector string. |
| 52 | static FailureOr<APInt> |
| 53 | parseBitVectorString(function_ref<InFlightDiagnostic()> emitError, |
| 54 | StringRef value) { |
| 55 | if (value[0] != '#') |
| 56 | return emitError() << "expected '#'" ; |
| 57 | |
| 58 | if (value.size() < 3) |
| 59 | return emitError() << "expected at least one digit" ; |
| 60 | |
| 61 | if (value[1] == 'b') |
| 62 | return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), |
| 63 | 2); |
| 64 | |
| 65 | if (value[1] == 'x') |
| 66 | return APInt((value.size() - 2) * 4, |
| 67 | std::string(value.begin() + 2, value.end()), 16); |
| 68 | |
| 69 | return emitError() << "expected either 'b' or 'x'" ; |
| 70 | } |
| 71 | |
| 72 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { |
| 73 | auto maybeValue = parseBitVectorString(nullptr, value); |
| 74 | |
| 75 | assert(succeeded(maybeValue) && "string must have SMT-LIB format" ); |
| 76 | return Base::get(context, *maybeValue); |
| 77 | } |
| 78 | |
| 79 | BitVectorAttr |
| 80 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
| 81 | MLIRContext *context, StringRef value) { |
| 82 | auto maybeValue = parseBitVectorString(emitError, value); |
| 83 | if (failed(maybeValue)) |
| 84 | return {}; |
| 85 | |
| 86 | return Base::getChecked(emitError, context, *maybeValue); |
| 87 | } |
| 88 | |
| 89 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, |
| 90 | unsigned width) { |
| 91 | return Base::get(context, APInt(width, value)); |
| 92 | } |
| 93 | |
| 94 | BitVectorAttr |
| 95 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
| 96 | MLIRContext *context, uint64_t value, |
| 97 | unsigned width) { |
| 98 | if (width < 64 && value >= (UINT64_C(1) << width)) { |
| 99 | emitError() << "value does not fit in a bit-vector of desired width" ; |
| 100 | return {}; |
| 101 | } |
| 102 | return Base::getChecked(emitError, context, APInt(width, value)); |
| 103 | } |
| 104 | |
| 105 | Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { |
| 106 | llvm::SMLoc loc = odsParser.getCurrentLocation(); |
| 107 | |
| 108 | APInt val; |
| 109 | if (odsParser.parseLess() || odsParser.parseInteger(val) || |
| 110 | odsParser.parseGreater()) |
| 111 | return {}; |
| 112 | |
| 113 | // Requires the use of `quantified(<attr>)` in operation assembly formats. |
| 114 | if (!odsType || !llvm::isa<BitVectorType>(odsType)) { |
| 115 | odsParser.emitError(loc) << "explicit bit-vector type required" ; |
| 116 | return {}; |
| 117 | } |
| 118 | |
| 119 | unsigned width = llvm::cast<BitVectorType>(odsType).getWidth(); |
| 120 | |
| 121 | if (width > val.getBitWidth()) { |
| 122 | // sext is always safe here, even for unsigned values, because the |
| 123 | // parseOptionalInteger method will return something with a zero in the |
| 124 | // top bits if it is a positive number. |
| 125 | val = val.sext(width); |
| 126 | } else if (width < val.getBitWidth()) { |
| 127 | // The parser can return an unnecessarily wide result. |
| 128 | // This isn't a problem, but truncating off bits is bad. |
| 129 | unsigned neededBits = |
| 130 | val.isNegative() ? val.getSignificantBits() : val.getActiveBits(); |
| 131 | if (width < neededBits) { |
| 132 | odsParser.emitError(loc) |
| 133 | << "integer value out of range for given bit-vector type " << odsType; |
| 134 | return {}; |
| 135 | } |
| 136 | val = val.trunc(width); |
| 137 | } |
| 138 | |
| 139 | return BitVectorAttr::get(odsParser.getContext(), val); |
| 140 | } |
| 141 | |
| 142 | void BitVectorAttr::print(AsmPrinter &odsPrinter) const { |
| 143 | // This printer only works for the extended format where the MLIR |
| 144 | // infrastructure prints the type for us. This means, the attribute should |
| 145 | // never be used without `quantified` in an assembly format. |
| 146 | odsPrinter << "<" << getValue() << ">" ; |
| 147 | } |
| 148 | |
| 149 | Type BitVectorAttr::getType() const { |
| 150 | return BitVectorType::get(getContext(), getValue().getBitWidth()); |
| 151 | } |
| 152 | |
| 153 | //===----------------------------------------------------------------------===// |
| 154 | // ODS Boilerplate |
| 155 | //===----------------------------------------------------------------------===// |
| 156 | |
| 157 | #define GET_ATTRDEF_CLASSES |
| 158 | #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc" |
| 159 | |
/// Register the dialect's ODS-defined attributes with the SMT dialect. The
/// attribute class list is supplied by the `GET_ATTRDEF_LIST` section of the
/// generated `SMTAttributes.cpp.inc` rather than spelled out by hand.
void SMTDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
      >();
}
| 166 | |