1 | //===- SMTAttributes.cpp - Implement SMT attributes -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir/Dialect/SMT/IR/SMTAttributes.h" |
10 | #include "mlir/Dialect/SMT/IR/SMTDialect.h" |
11 | #include "mlir/Dialect/SMT/IR/SMTTypes.h" |
12 | #include "mlir/IR/Builders.h" |
13 | #include "mlir/IR/DialectImplementation.h" |
14 | #include "llvm/ADT/TypeSwitch.h" |
15 | #include "llvm/Support/Format.h" |
16 | |
17 | using namespace mlir; |
18 | using namespace mlir::smt; |
19 | |
20 | //===----------------------------------------------------------------------===// |
21 | // BitVectorAttr |
22 | //===----------------------------------------------------------------------===// |
23 | |
24 | LogicalResult BitVectorAttr::verify( |
25 | function_ref<InFlightDiagnostic()> emitError, |
26 | APInt value) { // NOLINT(performance-unnecessary-value-param) |
27 | if (value.getBitWidth() < 1) |
28 | return emitError() << "bit-width must be at least 1, but got " |
29 | << value.getBitWidth(); |
30 | return success(); |
31 | } |
32 | |
33 | std::string BitVectorAttr::getValueAsString(bool prefix) const { |
34 | unsigned width = getValue().getBitWidth(); |
35 | SmallVector<char> toPrint; |
36 | StringRef pref = prefix ? "#" : "" ; |
37 | if (width % 4 == 0) { |
38 | getValue().toString(toPrint, 16, false, false, false); |
39 | // APInt's 'toString' omits leading zeros. However, those are critical here |
40 | // because they determine the bit-width of the bit-vector. |
41 | SmallVector<char> leadingZeros(width / 4 - toPrint.size(), '0'); |
42 | return (pref + "x" + Twine(leadingZeros) + toPrint).str(); |
43 | } |
44 | |
45 | getValue().toString(toPrint, 2, false, false, false); |
46 | // APInt's 'toString' omits leading zeros |
47 | SmallVector<char> leadingZeros(width - toPrint.size(), '0'); |
48 | return (pref + "b" + Twine(leadingZeros) + toPrint).str(); |
49 | } |
50 | |
51 | /// Parse an SMT-LIB formatted bit-vector string. |
52 | static FailureOr<APInt> |
53 | parseBitVectorString(function_ref<InFlightDiagnostic()> emitError, |
54 | StringRef value) { |
55 | if (value[0] != '#') |
56 | return emitError() << "expected '#'" ; |
57 | |
58 | if (value.size() < 3) |
59 | return emitError() << "expected at least one digit" ; |
60 | |
61 | if (value[1] == 'b') |
62 | return APInt(value.size() - 2, std::string(value.begin() + 2, value.end()), |
63 | 2); |
64 | |
65 | if (value[1] == 'x') |
66 | return APInt((value.size() - 2) * 4, |
67 | std::string(value.begin() + 2, value.end()), 16); |
68 | |
69 | return emitError() << "expected either 'b' or 'x'" ; |
70 | } |
71 | |
72 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, StringRef value) { |
73 | auto maybeValue = parseBitVectorString(nullptr, value); |
74 | |
75 | assert(succeeded(maybeValue) && "string must have SMT-LIB format" ); |
76 | return Base::get(context, *maybeValue); |
77 | } |
78 | |
79 | BitVectorAttr |
80 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
81 | MLIRContext *context, StringRef value) { |
82 | auto maybeValue = parseBitVectorString(emitError, value); |
83 | if (failed(maybeValue)) |
84 | return {}; |
85 | |
86 | return Base::getChecked(emitError, context, *maybeValue); |
87 | } |
88 | |
89 | BitVectorAttr BitVectorAttr::get(MLIRContext *context, uint64_t value, |
90 | unsigned width) { |
91 | return Base::get(context, APInt(width, value)); |
92 | } |
93 | |
94 | BitVectorAttr |
95 | BitVectorAttr::getChecked(function_ref<InFlightDiagnostic()> emitError, |
96 | MLIRContext *context, uint64_t value, |
97 | unsigned width) { |
98 | if (width < 64 && value >= (UINT64_C(1) << width)) { |
99 | emitError() << "value does not fit in a bit-vector of desired width" ; |
100 | return {}; |
101 | } |
102 | return Base::getChecked(emitError, context, APInt(width, value)); |
103 | } |
104 | |
105 | Attribute BitVectorAttr::parse(AsmParser &odsParser, Type odsType) { |
106 | llvm::SMLoc loc = odsParser.getCurrentLocation(); |
107 | |
108 | APInt val; |
109 | if (odsParser.parseLess() || odsParser.parseInteger(val) || |
110 | odsParser.parseGreater()) |
111 | return {}; |
112 | |
113 | // Requires the use of `quantified(<attr>)` in operation assembly formats. |
114 | if (!odsType || !llvm::isa<BitVectorType>(odsType)) { |
115 | odsParser.emitError(loc) << "explicit bit-vector type required" ; |
116 | return {}; |
117 | } |
118 | |
119 | unsigned width = llvm::cast<BitVectorType>(odsType).getWidth(); |
120 | |
121 | if (width > val.getBitWidth()) { |
122 | // sext is always safe here, even for unsigned values, because the |
123 | // parseOptionalInteger method will return something with a zero in the |
124 | // top bits if it is a positive number. |
125 | val = val.sext(width); |
126 | } else if (width < val.getBitWidth()) { |
127 | // The parser can return an unnecessarily wide result. |
128 | // This isn't a problem, but truncating off bits is bad. |
129 | unsigned neededBits = |
130 | val.isNegative() ? val.getSignificantBits() : val.getActiveBits(); |
131 | if (width < neededBits) { |
132 | odsParser.emitError(loc) |
133 | << "integer value out of range for given bit-vector type " << odsType; |
134 | return {}; |
135 | } |
136 | val = val.trunc(width); |
137 | } |
138 | |
139 | return BitVectorAttr::get(odsParser.getContext(), val); |
140 | } |
141 | |
142 | void BitVectorAttr::print(AsmPrinter &odsPrinter) const { |
143 | // This printer only works for the extended format where the MLIR |
144 | // infrastructure prints the type for us. This means, the attribute should |
145 | // never be used without `quantified` in an assembly format. |
146 | odsPrinter << "<" << getValue() << ">" ; |
147 | } |
148 | |
149 | Type BitVectorAttr::getType() const { |
150 | return BitVectorType::get(getContext(), getValue().getBitWidth()); |
151 | } |
152 | |
153 | //===----------------------------------------------------------------------===// |
154 | // ODS Boilerplate |
155 | //===----------------------------------------------------------------------===// |
156 | |
157 | #define GET_ATTRDEF_CLASSES |
158 | #include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc" |
159 | |
/// Register all attributes from the ODS-generated attribute list
/// (`GET_ATTRDEF_LIST` in SMTAttributes.cpp.inc) with the SMT dialect.
void SMTDialect::registerAttributes() {
  // The generated include expands to the comma-separated list of attribute
  // classes that forms the template argument list of `addAttributes`.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SMT/IR/SMTAttributes.cpp.inc"
      >();
}
166 | |