1 | //===- SMTOps.cpp ---------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/APSInt.h"

#include <cstdint>
#include <limits>
13 | |
14 | using namespace mlir; |
15 | using namespace smt; |
16 | using namespace mlir; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // BVConstantOp |
20 | //===----------------------------------------------------------------------===// |
21 | |
22 | LogicalResult BVConstantOp::inferReturnTypes( |
23 | mlir::MLIRContext *context, std::optional<mlir::Location> location, |
24 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, |
25 | ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, |
26 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { |
27 | inferredReturnTypes.push_back( |
28 | properties.as<Properties *>()->getValue().getType()); |
29 | return success(); |
30 | } |
31 | |
32 | void BVConstantOp::getAsmResultNames( |
33 | function_ref<void(Value, StringRef)> setNameFn) { |
34 | SmallVector<char, 128> specialNameBuffer; |
35 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
36 | specialName << "c" << getValue().getValue() << "_bv" |
37 | << getValue().getValue().getBitWidth(); |
38 | setNameFn(getResult(), specialName.str()); |
39 | } |
40 | |
41 | OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { |
42 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
43 | return getValueAttr(); |
44 | } |
45 | |
46 | //===----------------------------------------------------------------------===// |
47 | // DeclareFunOp |
48 | //===----------------------------------------------------------------------===// |
49 | |
50 | void DeclareFunOp::getAsmResultNames( |
51 | function_ref<void(Value, StringRef)> setNameFn) { |
52 | setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "" ); |
53 | } |
54 | |
55 | //===----------------------------------------------------------------------===// |
56 | // SolverOp |
57 | //===----------------------------------------------------------------------===// |
58 | |
59 | LogicalResult SolverOp::verifyRegions() { |
60 | if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes()) |
61 | return emitOpError() << "types of yielded values must match return values" ; |
62 | if (getBody()->getArgumentTypes() != getInputs().getTypes()) |
63 | return emitOpError() |
64 | << "block argument types must match the types of the 'inputs'" ; |
65 | |
66 | return success(); |
67 | } |
68 | |
69 | //===----------------------------------------------------------------------===// |
70 | // CheckOp |
71 | //===----------------------------------------------------------------------===// |
72 | |
73 | LogicalResult CheckOp::verifyRegions() { |
74 | if (getSatRegion().front().getTerminator()->getOperands().getTypes() != |
75 | getResultTypes()) |
76 | return emitOpError() << "types of yielded values in 'sat' region must " |
77 | "match return values" ; |
78 | if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() != |
79 | getResultTypes()) |
80 | return emitOpError() << "types of yielded values in 'unknown' region must " |
81 | "match return values" ; |
82 | if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() != |
83 | getResultTypes()) |
84 | return emitOpError() << "types of yielded values in 'unsat' region must " |
85 | "match return values" ; |
86 | |
87 | return success(); |
88 | } |
89 | |
90 | //===----------------------------------------------------------------------===// |
91 | // EqOp |
92 | //===----------------------------------------------------------------------===// |
93 | |
94 | static LogicalResult |
95 | parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, |
96 | OperationState &result) { |
97 | SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs; |
98 | SMLoc loc = parser.getCurrentLocation(); |
99 | Type type; |
100 | |
101 | if (parser.parseOperandList(result&: inputs) || |
102 | parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() || |
103 | parser.parseType(result&: type)) |
104 | return failure(); |
105 | |
106 | result.addTypes(BoolType::get(parser.getContext())); |
107 | if (parser.resolveOperands(operands&: inputs, types: SmallVector<Type>(inputs.size(), type), |
108 | loc, result&: result.operands)) |
109 | return failure(); |
110 | |
111 | return success(); |
112 | } |
113 | |
114 | ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) { |
115 | return parseSameOperandTypeVariadicToBoolOp(parser, result); |
116 | } |
117 | |
118 | void EqOp::print(OpAsmPrinter &printer) { |
119 | printer << ' ' << getInputs(); |
120 | printer.printOptionalAttrDict(getOperation()->getAttrs()); |
121 | printer << " : " << getInputs().front().getType(); |
122 | } |
123 | |
124 | LogicalResult EqOp::verify() { |
125 | if (getInputs().size() < 2) |
126 | return emitOpError() << "'inputs' must have at least size 2, but got " |
127 | << getInputs().size(); |
128 | |
129 | return success(); |
130 | } |
131 | |
132 | //===----------------------------------------------------------------------===// |
133 | // DistinctOp |
134 | //===----------------------------------------------------------------------===// |
135 | |
136 | ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) { |
137 | return parseSameOperandTypeVariadicToBoolOp(parser, result); |
138 | } |
139 | |
140 | void DistinctOp::print(OpAsmPrinter &printer) { |
141 | printer << ' ' << getInputs(); |
142 | printer.printOptionalAttrDict(getOperation()->getAttrs()); |
143 | printer << " : " << getInputs().front().getType(); |
144 | } |
145 | |
146 | LogicalResult DistinctOp::verify() { |
147 | if (getInputs().size() < 2) |
148 | return emitOpError() << "'inputs' must have at least size 2, but got " |
149 | << getInputs().size(); |
150 | |
151 | return success(); |
152 | } |
153 | |
154 | //===----------------------------------------------------------------------===// |
155 | // ExtractOp |
156 | //===----------------------------------------------------------------------===// |
157 | |
158 | LogicalResult ExtractOp::verify() { |
159 | unsigned rangeWidth = getType().getWidth(); |
160 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
161 | if (getLowBit() + rangeWidth > inputWidth) |
162 | return emitOpError("range to be extracted is too big, expected range " |
163 | "starting at index " ) |
164 | << getLowBit() << " of length " << rangeWidth |
165 | << " requires input width of at least " << (getLowBit() + rangeWidth) |
166 | << ", but the input width is only " << inputWidth; |
167 | return success(); |
168 | } |
169 | |
170 | //===----------------------------------------------------------------------===// |
171 | // ConcatOp |
172 | //===----------------------------------------------------------------------===// |
173 | |
174 | LogicalResult ConcatOp::inferReturnTypes( |
175 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
176 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
177 | SmallVectorImpl<Type> &inferredReturnTypes) { |
178 | inferredReturnTypes.push_back(BitVectorType::get( |
179 | context, cast<BitVectorType>(operands[0].getType()).getWidth() + |
180 | cast<BitVectorType>(operands[1].getType()).getWidth())); |
181 | return success(); |
182 | } |
183 | |
184 | //===----------------------------------------------------------------------===// |
185 | // RepeatOp |
186 | //===----------------------------------------------------------------------===// |
187 | |
188 | LogicalResult RepeatOp::verify() { |
189 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
190 | unsigned resultWidth = getType().getWidth(); |
191 | if (resultWidth % inputWidth != 0) |
192 | return emitOpError() << "result bit-vector width must be a multiple of the " |
193 | "input bit-vector width" ; |
194 | |
195 | return success(); |
196 | } |
197 | |
198 | unsigned RepeatOp::getCount() { |
199 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
200 | unsigned resultWidth = getType().getWidth(); |
201 | return resultWidth / inputWidth; |
202 | } |
203 | |
204 | void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count, |
205 | Value input) { |
206 | unsigned inputWidth = cast<BitVectorType>(input.getType()).getWidth(); |
207 | Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count); |
208 | build(builder, state, resultTy, input); |
209 | } |
210 | |
211 | ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) { |
212 | OpAsmParser::UnresolvedOperand input; |
213 | Type inputType; |
214 | llvm::SMLoc countLoc = parser.getCurrentLocation(); |
215 | |
216 | APInt count; |
217 | if (parser.parseInteger(count) || parser.parseKeyword("times" )) |
218 | return failure(); |
219 | |
220 | if (count.isNonPositive()) |
221 | return parser.emitError(countLoc) << "integer must be positive" ; |
222 | |
223 | llvm::SMLoc inputLoc = parser.getCurrentLocation(); |
224 | if (parser.parseOperand(input) || |
225 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
226 | parser.parseType(inputType)) |
227 | return failure(); |
228 | |
229 | if (parser.resolveOperand(input, inputType, result.operands)) |
230 | return failure(); |
231 | |
232 | auto bvInputTy = dyn_cast<BitVectorType>(inputType); |
233 | if (!bvInputTy) |
234 | return parser.emitError(inputLoc) << "input must have bit-vector type" ; |
235 | |
236 | // Make sure no assertions can trigger and no silent overflows can happen |
237 | // Bit-width is stored as 'int64_t' parameter in 'BitVectorType' |
238 | const unsigned maxBw = 63; |
239 | if (count.getActiveBits() > maxBw) |
240 | return parser.emitError(countLoc) |
241 | << "integer must fit into " << maxBw << " bits" ; |
242 | |
243 | // Store multiplication in an APInt twice the size to not have any overflow |
244 | // and check if it can be truncated to 'maxBw' bits without cutting of |
245 | // important bits. |
246 | APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw); |
247 | if (resultBw.getActiveBits() > maxBw) |
248 | return parser.emitError(countLoc) |
249 | << "result bit-width (provided integer times bit-width of the input " |
250 | "type) must fit into " |
251 | << maxBw << " bits" ; |
252 | |
253 | Type resultTy = |
254 | BitVectorType::get(parser.getContext(), resultBw.getZExtValue()); |
255 | result.addTypes(resultTy); |
256 | return success(); |
257 | } |
258 | |
259 | void RepeatOp::print(OpAsmPrinter &printer) { |
260 | printer << " " << getCount() << " times " << getInput(); |
261 | printer.printOptionalAttrDict((*this)->getAttrs()); |
262 | printer << " : " << getInput().getType(); |
263 | } |
264 | |
265 | //===----------------------------------------------------------------------===// |
266 | // BoolConstantOp |
267 | //===----------------------------------------------------------------------===// |
268 | |
269 | void BoolConstantOp::getAsmResultNames( |
270 | function_ref<void(Value, StringRef)> setNameFn) { |
271 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
272 | } |
273 | |
274 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
275 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
276 | return getValueAttr(); |
277 | } |
278 | |
279 | //===----------------------------------------------------------------------===// |
280 | // IntConstantOp |
281 | //===----------------------------------------------------------------------===// |
282 | |
283 | void IntConstantOp::getAsmResultNames( |
284 | function_ref<void(Value, StringRef)> setNameFn) { |
285 | SmallVector<char, 32> specialNameBuffer; |
286 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
287 | specialName << "c" << getValue(); |
288 | setNameFn(getResult(), specialName.str()); |
289 | } |
290 | |
291 | OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) { |
292 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
293 | return getValueAttr(); |
294 | } |
295 | |
296 | void IntConstantOp::print(OpAsmPrinter &p) { |
297 | p << " " << getValue(); |
298 | p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value" }); |
299 | } |
300 | |
301 | ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) { |
302 | APInt value; |
303 | if (parser.parseInteger(value)) |
304 | return failure(); |
305 | |
306 | result.getOrAddProperties<Properties>().setValue( |
307 | IntegerAttr::get(parser.getContext(), APSInt(value))); |
308 | |
309 | if (parser.parseOptionalAttrDict(result.attributes)) |
310 | return failure(); |
311 | |
312 | result.addTypes(smt::IntType::get(parser.getContext())); |
313 | return success(); |
314 | } |
315 | |
316 | //===----------------------------------------------------------------------===// |
317 | // ForallOp |
318 | //===----------------------------------------------------------------------===// |
319 | |
320 | template <typename QuantifierOp> |
321 | static LogicalResult verifyQuantifierRegions(QuantifierOp op) { |
322 | if (op.getBoundVarNames() && |
323 | op.getBody().getNumArguments() != op.getBoundVarNames()->size()) |
324 | return op.emitOpError( |
325 | "number of bound variable names must match number of block arguments" ); |
326 | if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType)) |
327 | return op.emitOpError() |
328 | << "bound variables must by any non-function SMT value" ; |
329 | |
330 | if (op.getBody().front().getTerminator()->getNumOperands() != 1) |
331 | return op.emitOpError("must have exactly one yielded value" ); |
332 | if (!isa<BoolType>( |
333 | op.getBody().front().getTerminator()->getOperand(0).getType())) |
334 | return op.emitOpError("yielded value must be of '!smt.bool' type" ); |
335 | |
336 | for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) { |
337 | unsigned i = regionWithIndex.index(); |
338 | Region ®ion = regionWithIndex.value(); |
339 | |
340 | if (op.getBody().getArgumentTypes() != region.getArgumentTypes()) |
341 | return op.emitOpError() |
342 | << "block argument number and types of the 'body' " |
343 | "and 'patterns' region #" |
344 | << i << " must match" ; |
345 | if (region.front().getTerminator()->getNumOperands() < 1) |
346 | return op.emitOpError() << "'patterns' region #" << i |
347 | << " must have at least one yielded value" ; |
348 | |
349 | // All operations in the 'patterns' region must be SMT operations. |
350 | auto result = region.walk([&](Operation *childOp) { |
351 | if (!isa<SMTDialect>(childOp->getDialect())) { |
352 | auto diag = op.emitOpError() |
353 | << "the 'patterns' region #" << i |
354 | << " may only contain SMT dialect operations" ; |
355 | diag.attachNote(childOp->getLoc()) << "first non-SMT operation here" ; |
356 | return WalkResult::interrupt(); |
357 | } |
358 | |
359 | // There may be no quantifier (or other variable binding) operations in |
360 | // the 'patterns' region. |
361 | if (isa<ForallOp, ExistsOp>(childOp)) { |
362 | auto diag = op.emitOpError() << "the 'patterns' region #" << i |
363 | << " must not contain " |
364 | "any variable binding operations" ; |
365 | diag.attachNote(childOp->getLoc()) << "first violating operation here" ; |
366 | return WalkResult::interrupt(); |
367 | } |
368 | |
369 | return WalkResult::advance(); |
370 | }); |
371 | if (result.wasInterrupted()) |
372 | return failure(); |
373 | } |
374 | |
375 | return success(); |
376 | } |
377 | |
378 | template <typename Properties> |
379 | static void buildQuantifier( |
380 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
381 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
382 | std::optional<ArrayRef<StringRef>> boundVarNames, |
383 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
384 | uint32_t weight, bool noPattern) { |
385 | odsState.addTypes(BoolType::get(odsBuilder.getContext())); |
386 | if (weight != 0) |
387 | odsState.getOrAddProperties<Properties>().weight = |
388 | odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight); |
389 | if (noPattern) |
390 | odsState.getOrAddProperties<Properties>().noPattern = |
391 | odsBuilder.getUnitAttr(); |
392 | if (boundVarNames.has_value()) { |
393 | SmallVector<Attribute> boundVarNamesList; |
394 | for (StringRef str : *boundVarNames) |
395 | boundVarNamesList.emplace_back(Args: odsBuilder.getStringAttr(str)); |
396 | odsState.getOrAddProperties<Properties>().boundVarNames = |
397 | odsBuilder.getArrayAttr(boundVarNamesList); |
398 | } |
399 | { |
400 | OpBuilder::InsertionGuard guard(odsBuilder); |
401 | Region *region = odsState.addRegion(); |
402 | Block *block = odsBuilder.createBlock(parent: region); |
403 | block->addArguments( |
404 | types: boundVarTypes, |
405 | locs: SmallVector<Location>(boundVarTypes.size(), odsState.location)); |
406 | Value returnVal = |
407 | bodyBuilder(odsBuilder, odsState.location, block->getArguments()); |
408 | odsBuilder.create<smt::YieldOp>(odsState.location, returnVal); |
409 | } |
410 | if (patternBuilder) { |
411 | Region *region = odsState.addRegion(); |
412 | OpBuilder::InsertionGuard guard(odsBuilder); |
413 | Block *block = odsBuilder.createBlock(parent: region); |
414 | block->addArguments( |
415 | types: boundVarTypes, |
416 | locs: SmallVector<Location>(boundVarTypes.size(), odsState.location)); |
417 | ValueRange returnVals = |
418 | patternBuilder(odsBuilder, odsState.location, block->getArguments()); |
419 | odsBuilder.create<smt::YieldOp>(odsState.location, returnVals); |
420 | } |
421 | } |
422 | |
423 | LogicalResult ForallOp::verify() { |
424 | if (!getPatterns().empty() && getNoPattern()) |
425 | return emitOpError() << "patterns and the no_pattern attribute must not be " |
426 | "specified at the same time" ; |
427 | |
428 | return success(); |
429 | } |
430 | |
431 | LogicalResult ForallOp::verifyRegions() { |
432 | return verifyQuantifierRegions(*this); |
433 | } |
434 | |
435 | void ForallOp::build( |
436 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
437 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
438 | std::optional<ArrayRef<StringRef>> boundVarNames, |
439 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
440 | uint32_t weight, bool noPattern) { |
441 | buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder, |
442 | boundVarNames, patternBuilder, weight, noPattern); |
443 | } |
444 | |
445 | //===----------------------------------------------------------------------===// |
446 | // ExistsOp |
447 | //===----------------------------------------------------------------------===// |
448 | |
449 | LogicalResult ExistsOp::verify() { |
450 | if (!getPatterns().empty() && getNoPattern()) |
451 | return emitOpError() << "patterns and the no_pattern attribute must not be " |
452 | "specified at the same time" ; |
453 | |
454 | return success(); |
455 | } |
456 | |
457 | LogicalResult ExistsOp::verifyRegions() { |
458 | return verifyQuantifierRegions(*this); |
459 | } |
460 | |
461 | void ExistsOp::build( |
462 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
463 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
464 | std::optional<ArrayRef<StringRef>> boundVarNames, |
465 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
466 | uint32_t weight, bool noPattern) { |
467 | buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder, |
468 | boundVarNames, patternBuilder, weight, noPattern); |
469 | } |
470 | |
471 | #define GET_OP_CLASSES |
472 | #include "mlir/Dialect/SMT/IR/SMT.cpp.inc" |
473 | |