1//===- SMTAttributes.cpp - Implement SMT attributes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/SMT/IR/SMTAttributes.h"
10#include "mlir/Dialect/SMT/IR/SMTDialect.h"
11#include "mlir/Dialect/SMT/IR/SMTTypes.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "llvm/ADT/TypeSwitch.h"
15#include "llvm/Support/Format.h"
16
17using namespace mlir;
18using namespace mlir::smt;
19
20//===----------------------------------------------------------------------===//
21// BitVectorAttr
22//===----------------------------------------------------------------------===//
23
24LogicalResult BitVectorAttr::verify(
25 function_ref<InFlightDiagnostic()> emitError,
26 APInt value) { // NOLINT(performance-unnecessary-value-param)
27 if (value.getBitWidth() < 1)
28 return emitError() << "bit-width must be at least 1, but got "
29 << value.getBitWidth();
30 return success();
31}
32
33std::string BitVectorAttr::getValueAsString(bool prefix) const {
34 unsigned width = getValue().getBitWidth();
35 SmallVector<char> toPrint;
36 StringRef pref = prefix ? "#" : "";
37 if (width % 4 == 0) {
38 getValue().toString(toPrint, 16, false, false, false);
39 // APInt's 'toString' omits leading zeros. However, those are critical here
40 // because they determine the bit-width of the bit-vector.
41 SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0');
42 return (pref + "x" + Twine(leadingZeros) + toPrint).str();
43 }
44
45 getValue().toString(toPrint, 2, false, false, false);
46 // APInt's 'toString' omits leading zeros
47 SmallVector<char> leadingZeros(width - toPrint.size(), '0');
48 return (pref + "b" + Twine(leadingZeros) + toPrint).str();
49}
50
51/// Parse an SMT-LIB formatted bit-vector string.
52static FailureOr<APInt>
53parseBitVectorString(function_ref<InFlightDiagnostic()> emitError,
54 StringRef value) {
55 if (value[0] != '#')
56 return emitError() << "expected '#'";
57
58 if (value.size() < 3)
59 return emitError() << "expected at least one digit";
60
61 if (value[1] == 'b')
62 return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()),
63 2);
64
65 if (value[1] == 'x')
66 return APInt((value.size() - 2) * 4,
67 std::string(value.begin() + 2, value.end()), 16);
68
69 return emitError() << "expected either 'b' or 'x'";
70}
71
72BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) {
73 auto maybeValue = parseBitVectorString(nullptr, value);
74
75 assert(succeeded(maybeValue) && "string must have SMT-LIB format");
76 return Base::get(context, *maybeValue);
77}
78
79BitVectorAttr
80BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
81 MLIRContext *context, StringRef value) {
82 auto maybeValue = parseBitVectorString(emitError, value);
83 if (failed(maybeValue))
84 return {};
85
86 return Base::getChecked(emitError, context, *maybeValue);
87}
88
89BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value,
90 unsigned width) {
91 return Base::get(context, APInt(width, value));
92}
93
94BitVectorAttr
95BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError,
96 MLIRContext *context, uint64_t value,
97 unsigned width) {
98 if (width < 64 && value >= (UINT64_C(1) << width)) {
99 emitError() << "value does not fit in a bit-vector of desired width";
100 return {};
101 }
102 return Base::getChecked(emitError, context, APInt(width, value));
103}
104
105Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) {
106 llvm::SMLoc loc = odsParser.getCurrentLocation();
107
108 APInt val;
109 if (odsParser.parseLess() || odsParser.parseInteger(val) ||
110 odsParser.parseGreater())
111 return {};
112
113 // Requires the use of `quantified(<attr>)` in operation assembly formats.
114 if (!odsType || !llvm::isa<BitVectorType>(odsType)) {
115 odsParser.emitError(loc) << "explicit bit-vector type required";
116 return {};
117 }
118
119 unsigned width = llvm::cast<BitVectorType>(odsType).getWidth();
120
121 if (width > val.getBitWidth()) {
122 // sext is always safe here, even for unsigned values, because the
123 // parseOptionalInteger method will return something with a zero in the
124 // top bits if it is a positive number.
125 val = val.sext(width);
126 } else if (width < val.getBitWidth()) {
127 // The parser can return an unnecessarily wide result.
128 // This isn't a problem, but truncating off bits is bad.
129 unsigned neededBits =
130 val.isNegative() ? val.getSignificantBits() : val.getActiveBits();
131 if (width < neededBits) {
132 odsParser.emitError(loc)
133 << "integer value out of range for given bit-vector type " << odsType;
134 return {};
135 }
136 val = val.trunc(width);
137 }
138
139 return BitVectorAttr::get(odsParser.getContext(), val);
140}
141
142void BitVectorAttr::print(AsmPrinter &odsPrinter) const {
143 // This printer only works for the extended format where the MLIR
144 // infrastructure prints the type for us. This means, the attribute should
145 // never be used without `quantified` in an assembly format.
146 odsPrinter << "<" << getValue() << ">";
147}
148
149Type BitVectorAttr::getType() const {
150 return BitVectorType::get(getContext(), getValue().getBitWidth());
151}
152
153//===----------------------------------------------------------------------===//
154// ODS Boilerplate
155//===----------------------------------------------------------------------===//
156
// Select the GET_ATTRDEF_CLASSES section of the ODS attribute include file
// (presumably the generated attribute class definitions).
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
159
/// Register the SMT dialect's attributes. The template argument list passed
/// to `addAttributes` is supplied by the GET_ATTRDEF_LIST section of the ODS
/// attribute include file.
void SMTDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
      >();
}
166

// Source: mlir/lib/Dialect/SMT/IR/SMTAttributes.cpp