| 1 | //===- SMTOps.cpp ---------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/APSInt.h"
#include <cstdint>

using namespace mlir;
using namespace smt;
using namespace mlir;
| 17 | |
| 18 | //===----------------------------------------------------------------------===// |
| 19 | // BVConstantOp |
| 20 | //===----------------------------------------------------------------------===// |
| 21 | |
| 22 | LogicalResult BVConstantOp::inferReturnTypes( |
| 23 | mlir::MLIRContext *context, std::optional<mlir::Location> location, |
| 24 | ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, |
| 25 | ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, |
| 26 | ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) { |
| 27 | inferredReturnTypes.push_back( |
| 28 | properties.as<Properties *>()->getValue().getType()); |
| 29 | return success(); |
| 30 | } |
| 31 | |
| 32 | void BVConstantOp::getAsmResultNames( |
| 33 | function_ref<void(Value, StringRef)> setNameFn) { |
| 34 | SmallVector<char, 128> specialNameBuffer; |
| 35 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 36 | specialName << "c" << getValue().getValue() << "_bv" |
| 37 | << getValue().getValue().getBitWidth(); |
| 38 | setNameFn(getResult(), specialName.str()); |
| 39 | } |
| 40 | |
| 41 | OpFoldResult BVConstantOp::fold(FoldAdaptor adaptor) { |
| 42 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
| 43 | return getValueAttr(); |
| 44 | } |
| 45 | |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | // DeclareFunOp |
| 48 | //===----------------------------------------------------------------------===// |
| 49 | |
| 50 | void DeclareFunOp::getAsmResultNames( |
| 51 | function_ref<void(Value, StringRef)> setNameFn) { |
| 52 | setNameFn(getResult(), getNamePrefix().has_value() ? *getNamePrefix() : "" ); |
| 53 | } |
| 54 | |
| 55 | //===----------------------------------------------------------------------===// |
| 56 | // SolverOp |
| 57 | //===----------------------------------------------------------------------===// |
| 58 | |
| 59 | LogicalResult SolverOp::verifyRegions() { |
| 60 | if (getBody()->getTerminator()->getOperands().getTypes() != getResultTypes()) |
| 61 | return emitOpError() << "types of yielded values must match return values" ; |
| 62 | if (getBody()->getArgumentTypes() != getInputs().getTypes()) |
| 63 | return emitOpError() |
| 64 | << "block argument types must match the types of the 'inputs'" ; |
| 65 | |
| 66 | return success(); |
| 67 | } |
| 68 | |
| 69 | //===----------------------------------------------------------------------===// |
| 70 | // CheckOp |
| 71 | //===----------------------------------------------------------------------===// |
| 72 | |
| 73 | LogicalResult CheckOp::verifyRegions() { |
| 74 | if (getSatRegion().front().getTerminator()->getOperands().getTypes() != |
| 75 | getResultTypes()) |
| 76 | return emitOpError() << "types of yielded values in 'sat' region must " |
| 77 | "match return values" ; |
| 78 | if (getUnknownRegion().front().getTerminator()->getOperands().getTypes() != |
| 79 | getResultTypes()) |
| 80 | return emitOpError() << "types of yielded values in 'unknown' region must " |
| 81 | "match return values" ; |
| 82 | if (getUnsatRegion().front().getTerminator()->getOperands().getTypes() != |
| 83 | getResultTypes()) |
| 84 | return emitOpError() << "types of yielded values in 'unsat' region must " |
| 85 | "match return values" ; |
| 86 | |
| 87 | return success(); |
| 88 | } |
| 89 | |
| 90 | //===----------------------------------------------------------------------===// |
| 91 | // EqOp |
| 92 | //===----------------------------------------------------------------------===// |
| 93 | |
| 94 | static LogicalResult |
| 95 | parseSameOperandTypeVariadicToBoolOp(OpAsmParser &parser, |
| 96 | OperationState &result) { |
| 97 | SmallVector<OpAsmParser::UnresolvedOperand, 4> inputs; |
| 98 | SMLoc loc = parser.getCurrentLocation(); |
| 99 | Type type; |
| 100 | |
| 101 | if (parser.parseOperandList(result&: inputs) || |
| 102 | parser.parseOptionalAttrDict(result&: result.attributes) || parser.parseColon() || |
| 103 | parser.parseType(result&: type)) |
| 104 | return failure(); |
| 105 | |
| 106 | result.addTypes(BoolType::get(parser.getContext())); |
| 107 | if (parser.resolveOperands(operands&: inputs, types: SmallVector<Type>(inputs.size(), type), |
| 108 | loc, result&: result.operands)) |
| 109 | return failure(); |
| 110 | |
| 111 | return success(); |
| 112 | } |
| 113 | |
| 114 | ParseResult EqOp::parse(OpAsmParser &parser, OperationState &result) { |
| 115 | return parseSameOperandTypeVariadicToBoolOp(parser, result); |
| 116 | } |
| 117 | |
| 118 | void EqOp::print(OpAsmPrinter &printer) { |
| 119 | printer << ' ' << getInputs(); |
| 120 | printer.printOptionalAttrDict(getOperation()->getAttrs()); |
| 121 | printer << " : " << getInputs().front().getType(); |
| 122 | } |
| 123 | |
| 124 | LogicalResult EqOp::verify() { |
| 125 | if (getInputs().size() < 2) |
| 126 | return emitOpError() << "'inputs' must have at least size 2, but got " |
| 127 | << getInputs().size(); |
| 128 | |
| 129 | return success(); |
| 130 | } |
| 131 | |
| 132 | //===----------------------------------------------------------------------===// |
| 133 | // DistinctOp |
| 134 | //===----------------------------------------------------------------------===// |
| 135 | |
| 136 | ParseResult DistinctOp::parse(OpAsmParser &parser, OperationState &result) { |
| 137 | return parseSameOperandTypeVariadicToBoolOp(parser, result); |
| 138 | } |
| 139 | |
| 140 | void DistinctOp::print(OpAsmPrinter &printer) { |
| 141 | printer << ' ' << getInputs(); |
| 142 | printer.printOptionalAttrDict(getOperation()->getAttrs()); |
| 143 | printer << " : " << getInputs().front().getType(); |
| 144 | } |
| 145 | |
| 146 | LogicalResult DistinctOp::verify() { |
| 147 | if (getInputs().size() < 2) |
| 148 | return emitOpError() << "'inputs' must have at least size 2, but got " |
| 149 | << getInputs().size(); |
| 150 | |
| 151 | return success(); |
| 152 | } |
| 153 | |
| 154 | //===----------------------------------------------------------------------===// |
| 155 | // ExtractOp |
| 156 | //===----------------------------------------------------------------------===// |
| 157 | |
| 158 | LogicalResult ExtractOp::verify() { |
| 159 | unsigned rangeWidth = getType().getWidth(); |
| 160 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
| 161 | if (getLowBit() + rangeWidth > inputWidth) |
| 162 | return emitOpError("range to be extracted is too big, expected range " |
| 163 | "starting at index " ) |
| 164 | << getLowBit() << " of length " << rangeWidth |
| 165 | << " requires input width of at least " << (getLowBit() + rangeWidth) |
| 166 | << ", but the input width is only " << inputWidth; |
| 167 | return success(); |
| 168 | } |
| 169 | |
| 170 | //===----------------------------------------------------------------------===// |
| 171 | // ConcatOp |
| 172 | //===----------------------------------------------------------------------===// |
| 173 | |
| 174 | LogicalResult ConcatOp::inferReturnTypes( |
| 175 | MLIRContext *context, std::optional<Location> location, ValueRange operands, |
| 176 | DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, |
| 177 | SmallVectorImpl<Type> &inferredReturnTypes) { |
| 178 | inferredReturnTypes.push_back(BitVectorType::get( |
| 179 | context, cast<BitVectorType>(operands[0].getType()).getWidth() + |
| 180 | cast<BitVectorType>(operands[1].getType()).getWidth())); |
| 181 | return success(); |
| 182 | } |
| 183 | |
| 184 | //===----------------------------------------------------------------------===// |
| 185 | // RepeatOp |
| 186 | //===----------------------------------------------------------------------===// |
| 187 | |
| 188 | LogicalResult RepeatOp::verify() { |
| 189 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
| 190 | unsigned resultWidth = getType().getWidth(); |
| 191 | if (resultWidth % inputWidth != 0) |
| 192 | return emitOpError() << "result bit-vector width must be a multiple of the " |
| 193 | "input bit-vector width" ; |
| 194 | |
| 195 | return success(); |
| 196 | } |
| 197 | |
| 198 | unsigned RepeatOp::getCount() { |
| 199 | unsigned inputWidth = cast<BitVectorType>(getInput().getType()).getWidth(); |
| 200 | unsigned resultWidth = getType().getWidth(); |
| 201 | return resultWidth / inputWidth; |
| 202 | } |
| 203 | |
| 204 | void RepeatOp::build(OpBuilder &builder, OperationState &state, unsigned count, |
| 205 | Value input) { |
| 206 | unsigned inputWidth = cast<BitVectorType>(input.getType()).getWidth(); |
| 207 | Type resultTy = BitVectorType::get(builder.getContext(), inputWidth * count); |
| 208 | build(builder, state, resultTy, input); |
| 209 | } |
| 210 | |
| 211 | ParseResult RepeatOp::parse(OpAsmParser &parser, OperationState &result) { |
| 212 | OpAsmParser::UnresolvedOperand input; |
| 213 | Type inputType; |
| 214 | llvm::SMLoc countLoc = parser.getCurrentLocation(); |
| 215 | |
| 216 | APInt count; |
| 217 | if (parser.parseInteger(count) || parser.parseKeyword("times" )) |
| 218 | return failure(); |
| 219 | |
| 220 | if (count.isNonPositive()) |
| 221 | return parser.emitError(countLoc) << "integer must be positive" ; |
| 222 | |
| 223 | llvm::SMLoc inputLoc = parser.getCurrentLocation(); |
| 224 | if (parser.parseOperand(input) || |
| 225 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || |
| 226 | parser.parseType(inputType)) |
| 227 | return failure(); |
| 228 | |
| 229 | if (parser.resolveOperand(input, inputType, result.operands)) |
| 230 | return failure(); |
| 231 | |
| 232 | auto bvInputTy = dyn_cast<BitVectorType>(inputType); |
| 233 | if (!bvInputTy) |
| 234 | return parser.emitError(inputLoc) << "input must have bit-vector type" ; |
| 235 | |
| 236 | // Make sure no assertions can trigger and no silent overflows can happen |
| 237 | // Bit-width is stored as 'int64_t' parameter in 'BitVectorType' |
| 238 | const unsigned maxBw = 63; |
| 239 | if (count.getActiveBits() > maxBw) |
| 240 | return parser.emitError(countLoc) |
| 241 | << "integer must fit into " << maxBw << " bits" ; |
| 242 | |
| 243 | // Store multiplication in an APInt twice the size to not have any overflow |
| 244 | // and check if it can be truncated to 'maxBw' bits without cutting of |
| 245 | // important bits. |
| 246 | APInt resultBw = bvInputTy.getWidth() * count.zext(2 * maxBw); |
| 247 | if (resultBw.getActiveBits() > maxBw) |
| 248 | return parser.emitError(countLoc) |
| 249 | << "result bit-width (provided integer times bit-width of the input " |
| 250 | "type) must fit into " |
| 251 | << maxBw << " bits" ; |
| 252 | |
| 253 | Type resultTy = |
| 254 | BitVectorType::get(parser.getContext(), resultBw.getZExtValue()); |
| 255 | result.addTypes(resultTy); |
| 256 | return success(); |
| 257 | } |
| 258 | |
| 259 | void RepeatOp::print(OpAsmPrinter &printer) { |
| 260 | printer << " " << getCount() << " times " << getInput(); |
| 261 | printer.printOptionalAttrDict((*this)->getAttrs()); |
| 262 | printer << " : " << getInput().getType(); |
| 263 | } |
| 264 | |
| 265 | //===----------------------------------------------------------------------===// |
| 266 | // BoolConstantOp |
| 267 | //===----------------------------------------------------------------------===// |
| 268 | |
| 269 | void BoolConstantOp::getAsmResultNames( |
| 270 | function_ref<void(Value, StringRef)> setNameFn) { |
| 271 | setNameFn(getResult(), getValue() ? "true" : "false" ); |
| 272 | } |
| 273 | |
| 274 | OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) { |
| 275 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
| 276 | return getValueAttr(); |
| 277 | } |
| 278 | |
| 279 | //===----------------------------------------------------------------------===// |
| 280 | // IntConstantOp |
| 281 | //===----------------------------------------------------------------------===// |
| 282 | |
| 283 | void IntConstantOp::getAsmResultNames( |
| 284 | function_ref<void(Value, StringRef)> setNameFn) { |
| 285 | SmallVector<char, 32> specialNameBuffer; |
| 286 | llvm::raw_svector_ostream specialName(specialNameBuffer); |
| 287 | specialName << "c" << getValue(); |
| 288 | setNameFn(getResult(), specialName.str()); |
| 289 | } |
| 290 | |
| 291 | OpFoldResult IntConstantOp::fold(FoldAdaptor adaptor) { |
| 292 | assert(adaptor.getOperands().empty() && "constant has no operands" ); |
| 293 | return getValueAttr(); |
| 294 | } |
| 295 | |
| 296 | void IntConstantOp::print(OpAsmPrinter &p) { |
| 297 | p << " " << getValue(); |
| 298 | p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"value" }); |
| 299 | } |
| 300 | |
| 301 | ParseResult IntConstantOp::parse(OpAsmParser &parser, OperationState &result) { |
| 302 | APInt value; |
| 303 | if (parser.parseInteger(value)) |
| 304 | return failure(); |
| 305 | |
| 306 | result.getOrAddProperties<Properties>().setValue( |
| 307 | IntegerAttr::get(parser.getContext(), APSInt(value))); |
| 308 | |
| 309 | if (parser.parseOptionalAttrDict(result.attributes)) |
| 310 | return failure(); |
| 311 | |
| 312 | result.addTypes(smt::IntType::get(parser.getContext())); |
| 313 | return success(); |
| 314 | } |
| 315 | |
| 316 | //===----------------------------------------------------------------------===// |
| 317 | // ForallOp |
| 318 | //===----------------------------------------------------------------------===// |
| 319 | |
| 320 | template <typename QuantifierOp> |
| 321 | static LogicalResult verifyQuantifierRegions(QuantifierOp op) { |
| 322 | if (op.getBoundVarNames() && |
| 323 | op.getBody().getNumArguments() != op.getBoundVarNames()->size()) |
| 324 | return op.emitOpError( |
| 325 | "number of bound variable names must match number of block arguments" ); |
| 326 | if (!llvm::all_of(op.getBody().getArgumentTypes(), isAnyNonFuncSMTValueType)) |
| 327 | return op.emitOpError() |
| 328 | << "bound variables must by any non-function SMT value" ; |
| 329 | |
| 330 | if (op.getBody().front().getTerminator()->getNumOperands() != 1) |
| 331 | return op.emitOpError("must have exactly one yielded value" ); |
| 332 | if (!isa<BoolType>( |
| 333 | op.getBody().front().getTerminator()->getOperand(0).getType())) |
| 334 | return op.emitOpError("yielded value must be of '!smt.bool' type" ); |
| 335 | |
| 336 | for (auto regionWithIndex : llvm::enumerate(op.getPatterns())) { |
| 337 | unsigned i = regionWithIndex.index(); |
| 338 | Region ®ion = regionWithIndex.value(); |
| 339 | |
| 340 | if (op.getBody().getArgumentTypes() != region.getArgumentTypes()) |
| 341 | return op.emitOpError() |
| 342 | << "block argument number and types of the 'body' " |
| 343 | "and 'patterns' region #" |
| 344 | << i << " must match" ; |
| 345 | if (region.front().getTerminator()->getNumOperands() < 1) |
| 346 | return op.emitOpError() << "'patterns' region #" << i |
| 347 | << " must have at least one yielded value" ; |
| 348 | |
| 349 | // All operations in the 'patterns' region must be SMT operations. |
| 350 | auto result = region.walk([&](Operation *childOp) { |
| 351 | if (!isa<SMTDialect>(childOp->getDialect())) { |
| 352 | auto diag = op.emitOpError() |
| 353 | << "the 'patterns' region #" << i |
| 354 | << " may only contain SMT dialect operations" ; |
| 355 | diag.attachNote(childOp->getLoc()) << "first non-SMT operation here" ; |
| 356 | return WalkResult::interrupt(); |
| 357 | } |
| 358 | |
| 359 | // There may be no quantifier (or other variable binding) operations in |
| 360 | // the 'patterns' region. |
| 361 | if (isa<ForallOp, ExistsOp>(childOp)) { |
| 362 | auto diag = op.emitOpError() << "the 'patterns' region #" << i |
| 363 | << " must not contain " |
| 364 | "any variable binding operations" ; |
| 365 | diag.attachNote(childOp->getLoc()) << "first violating operation here" ; |
| 366 | return WalkResult::interrupt(); |
| 367 | } |
| 368 | |
| 369 | return WalkResult::advance(); |
| 370 | }); |
| 371 | if (result.wasInterrupted()) |
| 372 | return failure(); |
| 373 | } |
| 374 | |
| 375 | return success(); |
| 376 | } |
| 377 | |
| 378 | template <typename Properties> |
| 379 | static void buildQuantifier( |
| 380 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
| 381 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
| 382 | std::optional<ArrayRef<StringRef>> boundVarNames, |
| 383 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
| 384 | uint32_t weight, bool noPattern) { |
| 385 | odsState.addTypes(BoolType::get(odsBuilder.getContext())); |
| 386 | if (weight != 0) |
| 387 | odsState.getOrAddProperties<Properties>().weight = |
| 388 | odsBuilder.getIntegerAttr(odsBuilder.getIntegerType(32), weight); |
| 389 | if (noPattern) |
| 390 | odsState.getOrAddProperties<Properties>().noPattern = |
| 391 | odsBuilder.getUnitAttr(); |
| 392 | if (boundVarNames.has_value()) { |
| 393 | SmallVector<Attribute> boundVarNamesList; |
| 394 | for (StringRef str : *boundVarNames) |
| 395 | boundVarNamesList.emplace_back(Args: odsBuilder.getStringAttr(str)); |
| 396 | odsState.getOrAddProperties<Properties>().boundVarNames = |
| 397 | odsBuilder.getArrayAttr(boundVarNamesList); |
| 398 | } |
| 399 | { |
| 400 | OpBuilder::InsertionGuard guard(odsBuilder); |
| 401 | Region *region = odsState.addRegion(); |
| 402 | Block *block = odsBuilder.createBlock(parent: region); |
| 403 | block->addArguments( |
| 404 | types: boundVarTypes, |
| 405 | locs: SmallVector<Location>(boundVarTypes.size(), odsState.location)); |
| 406 | Value returnVal = |
| 407 | bodyBuilder(odsBuilder, odsState.location, block->getArguments()); |
| 408 | odsBuilder.create<smt::YieldOp>(odsState.location, returnVal); |
| 409 | } |
| 410 | if (patternBuilder) { |
| 411 | Region *region = odsState.addRegion(); |
| 412 | OpBuilder::InsertionGuard guard(odsBuilder); |
| 413 | Block *block = odsBuilder.createBlock(parent: region); |
| 414 | block->addArguments( |
| 415 | types: boundVarTypes, |
| 416 | locs: SmallVector<Location>(boundVarTypes.size(), odsState.location)); |
| 417 | ValueRange returnVals = |
| 418 | patternBuilder(odsBuilder, odsState.location, block->getArguments()); |
| 419 | odsBuilder.create<smt::YieldOp>(odsState.location, returnVals); |
| 420 | } |
| 421 | } |
| 422 | |
| 423 | LogicalResult ForallOp::verify() { |
| 424 | if (!getPatterns().empty() && getNoPattern()) |
| 425 | return emitOpError() << "patterns and the no_pattern attribute must not be " |
| 426 | "specified at the same time" ; |
| 427 | |
| 428 | return success(); |
| 429 | } |
| 430 | |
| 431 | LogicalResult ForallOp::verifyRegions() { |
| 432 | return verifyQuantifierRegions(*this); |
| 433 | } |
| 434 | |
| 435 | void ForallOp::build( |
| 436 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
| 437 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
| 438 | std::optional<ArrayRef<StringRef>> boundVarNames, |
| 439 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
| 440 | uint32_t weight, bool noPattern) { |
| 441 | buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder, |
| 442 | boundVarNames, patternBuilder, weight, noPattern); |
| 443 | } |
| 444 | |
| 445 | //===----------------------------------------------------------------------===// |
| 446 | // ExistsOp |
| 447 | //===----------------------------------------------------------------------===// |
| 448 | |
| 449 | LogicalResult ExistsOp::verify() { |
| 450 | if (!getPatterns().empty() && getNoPattern()) |
| 451 | return emitOpError() << "patterns and the no_pattern attribute must not be " |
| 452 | "specified at the same time" ; |
| 453 | |
| 454 | return success(); |
| 455 | } |
| 456 | |
| 457 | LogicalResult ExistsOp::verifyRegions() { |
| 458 | return verifyQuantifierRegions(*this); |
| 459 | } |
| 460 | |
| 461 | void ExistsOp::build( |
| 462 | OpBuilder &odsBuilder, OperationState &odsState, TypeRange boundVarTypes, |
| 463 | function_ref<Value(OpBuilder &, Location, ValueRange)> bodyBuilder, |
| 464 | std::optional<ArrayRef<StringRef>> boundVarNames, |
| 465 | function_ref<ValueRange(OpBuilder &, Location, ValueRange)> patternBuilder, |
| 466 | uint32_t weight, bool noPattern) { |
| 467 | buildQuantifier<Properties>(odsBuilder, odsState, boundVarTypes, bodyBuilder, |
| 468 | boundVarNames, patternBuilder, weight, noPattern); |
| 469 | } |
| 470 | |
| 471 | #define GET_OP_CLASSES |
| 472 | #include "mlir/Dialect/SMT/IR/SMT.cpp.inc" |
| 473 | |