| 1 | //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "mlir/Dialect/SMT/IR/SMTAttributes.h" |
| 10 | #include "mlir/Dialect/SMT/IR/SMTDialect.h" |
| 11 | #include "mlir/Dialect/SMT/IR/SMTTypes.h" |
| 12 | #include "mlir/IR/Builders.h" |
| 13 | #include "mlir/IR/DialectImplementation.h" |
| 14 | #include "llvm/ADT/TypeSwitch.h" |
| 15 | |
| 16 | using namespace mlir; |
| 17 | using namespace mlir::smt; |
| 18 | |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | // BitVectorAttr |
| 21 | //===----------------------------------------------------------------------===// |
| 22 | |
| 23 | LogicalResult BitVectorAttr::verify( |
| 24 | function_ref<InFlightDiagnostic()> emitError, |
| 25 | APInt value) { // NOLINT(performance-unnecessary-value-param) |
| 26 | if (value.getBitWidth() < 1) |
| 27 | return emitError() << "bit-width must be at least 1, but got " |
| 28 | << value.getBitWidth(); |
| 29 | return success(); |
| 30 | } |
| 31 | |
| 32 | std::string BitVectorAttr::getValueAsString(bool prefix) const { |
| 33 | unsigned width = getValue().getBitWidth(); |
| 34 | SmallVector<char> toPrint; |
| 35 | StringRef pref = prefix ? "#" : "" ; |
| 36 | if (width % 4 == 0) { |
| 37 | getValue().toString(Str&: toPrint, Radix: 16, Signed: false, formatAsCLiteral: false, UpperCase: false); |
| 38 | // APInt's 'toString' omits leading zeros. However, those are critical here |
| 39 | // because they determine the bit-width of the bit-vector. |
| 40 | SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0'); |
| 41 | return (pref + "x" + Twine(leadingZeros) + toPrint).str(); |
| 42 | } |
| 43 | |
| 44 | getValue().toString(Str&: toPrint, Radix: 2, Signed: false, formatAsCLiteral: false, UpperCase: false); |
| 45 | // APInt's 'toString' omits leading zeros |
| 46 | SmallVector<char> leadingZeros(width - toPrint.size(), '0'); |
| 47 | return (pref + "b" + Twine(leadingZeros) + toPrint).str(); |
| 48 | } |
| 49 | |
| 50 | /// Parse an SMT-LIB formatted bit-vector string. |
| 51 | static FailureOr<APInt> |
| 52 | parseBitVectorString(function_ref<InFlightDiagnostic()> emitError, |
| 53 | StringRef value) { |
| 54 | if (value[0] != '#') |
| 55 | return emitError() << "expected '#'" ; |
| 56 | |
| 57 | if (value.size() < 3) |
| 58 | return emitError() << "expected at least one digit" ; |
| 59 | |
| 60 | if (value[1] == 'b') |
| 61 | return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), |
| 62 | 2); |
| 63 | |
| 64 | if (value[1] == 'x') |
| 65 | return APInt((value.size() - 2) * 4, |
| 66 | std::string(value.begin() + 2, value.end()), 16); |
| 67 | |
| 68 | return emitError() << "expected either 'b' or 'x'" ; |
| 69 | } |
| 70 | |
| 71 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { |
| 72 | auto maybeValue = parseBitVectorString(emitError: nullptr, value); |
| 73 | |
| 74 | assert(succeeded(maybeValue) && "string must have SMT-LIB format" ); |
| 75 | return Base::get(ctx: context, args&: *maybeValue); |
| 76 | } |
| 77 | |
| 78 | BitVectorAttr |
| 79 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
| 80 | MLIRContext *context, StringRef value) { |
| 81 | auto maybeValue = parseBitVectorString(emitError, value); |
| 82 | if (failed(Result: maybeValue)) |
| 83 | return {}; |
| 84 | |
| 85 | return Base::getChecked(emitErrorFn: emitError, ctx: context, args: *maybeValue); |
| 86 | } |
| 87 | |
| 88 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, |
| 89 | unsigned width) { |
| 90 | return Base::get(ctx: context, args: APInt(width, value)); |
| 91 | } |
| 92 | |
| 93 | BitVectorAttr |
| 94 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
| 95 | MLIRContext *context, uint64_t value, |
| 96 | unsigned width) { |
| 97 | if (width < 64 && value >= (UINT64_C(1) << width)) { |
| 98 | emitError() << "value does not fit in a bit-vector of desired width" ; |
| 99 | return {}; |
| 100 | } |
| 101 | return Base::getChecked(emitErrorFn: emitError, ctx: context, args: APInt(width, value)); |
| 102 | } |
| 103 | |
| 104 | Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { |
| 105 | llvm::SMLoc loc = odsParser.getCurrentLocation(); |
| 106 | |
| 107 | APInt val; |
| 108 | if (odsParser.parseLess() || odsParser.parseInteger(result&: val) || |
| 109 | odsParser.parseGreater()) |
| 110 | return {}; |
| 111 | |
| 112 | // Requires the use of `quantified(<attr>)` in operation assembly formats. |
| 113 | if (!odsType || !llvm::isa<BitVectorType>(Val: odsType)) { |
| 114 | odsParser.emitError(loc) << "explicit bit-vector type required" ; |
| 115 | return {}; |
| 116 | } |
| 117 | |
| 118 | unsigned width = llvm::cast<BitVectorType>(Val&: odsType).getWidth(); |
| 119 | |
| 120 | if (width > val.getBitWidth()) { |
| 121 | // sext is always safe here, even for unsigned values, because the |
| 122 | // parseOptionalInteger method will return something with a zero in the |
| 123 | // top bits if it is a positive number. |
| 124 | val = val.sext(width); |
| 125 | } else if (width < val.getBitWidth()) { |
| 126 | // The parser can return an unnecessarily wide result. |
| 127 | // This isn't a problem, but truncating off bits is bad. |
| 128 | unsigned neededBits = |
| 129 | val.isNegative() ? val.getSignificantBits() : val.getActiveBits(); |
| 130 | if (width < neededBits) { |
| 131 | odsParser.emitError(loc) |
| 132 | << "integer value out of range for given bit-vector type " << odsType; |
| 133 | return {}; |
| 134 | } |
| 135 | val = val.trunc(width); |
| 136 | } |
| 137 | |
| 138 | return BitVectorAttr::get(context: odsParser.getContext(), value: val); |
| 139 | } |
| 140 | |
| 141 | void BitVectorAttr::print(AsmPrinter &odsPrinter) const { |
| 142 | // This printer only works for the extended format where the MLIR |
| 143 | // infrastructure prints the type for us. This means, the attribute should |
| 144 | // never be used without `quantified` in an assembly format. |
| 145 | odsPrinter << "<" << getValue() << ">" ; |
| 146 | } |
| 147 | |
| 148 | Type BitVectorAttr::getType() const { |
| 149 | return BitVectorType::get(context: getContext(), width: getValue().getBitWidth()); |
| 150 | } |
| 151 | |
| 152 | //===----------------------------------------------------------------------===// |
| 153 | // ODS Boilerplate |
| 154 | //===----------------------------------------------------------------------===// |
| 155 | |
| 156 | #define GET_ATTRDEF_CLASSES |
| 157 | #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc" |
| 158 | |
/// Register the SMT dialect's attributes with the dialect.
void SMTDialect::registerAttributes() {
  // The template argument list is generated by ODS. Defining GET_ATTRDEF_LIST
  // makes the included file expand to the comma-separated attribute classes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
      >();
}
| 165 | |