1 | //===- SMT.cpp - C interface for the SMT dialect --------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/Dialect/SMT.h" |
10 | #include "mlir/CAPI/Registration.h" |
11 | #include "mlir/Dialect/SMT/IR/SMTAttributes.h" |
12 | #include "mlir/Dialect/SMT/IR/SMTDialect.h" |
13 | #include "mlir/Dialect/SMT/IR/SMTTypes.h" |
14 | |
15 | using namespace mlir; |
16 | using namespace smt; |
17 | |
18 | //===----------------------------------------------------------------------===// |
19 | // Dialect API. |
20 | //===----------------------------------------------------------------------===// |
21 | |
// Expands (macro from mlir/CAPI/Registration.h) to the C-API dialect
// registration definitions for the SMT dialect. Arguments: the C-API name
// (`SMT`), presumably the dialect namespace (`smt`) -- confirm against the
// macro definition -- and the C++ dialect class being registered.
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect)
23 | |
24 | //===----------------------------------------------------------------------===// |
25 | // Type API. |
26 | //===----------------------------------------------------------------------===// |
27 | |
28 | bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) { |
29 | return isAnyNonFuncSMTValueType(type: unwrap(c: type)); |
30 | } |
31 | |
32 | bool mlirSMTTypeIsAnySMTValueType(MlirType type) { |
33 | return isAnySMTValueType(type: unwrap(c: type)); |
34 | } |
35 | |
36 | bool mlirSMTTypeIsAArray(MlirType type) { return isa<ArrayType>(unwrap(type)); } |
37 | |
38 | MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType, |
39 | MlirType rangeType) { |
40 | return wrap( |
41 | ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType))); |
42 | } |
43 | |
44 | bool mlirSMTTypeIsABitVector(MlirType type) { |
45 | return isa<BitVectorType>(unwrap(type)); |
46 | } |
47 | |
48 | MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) { |
49 | return wrap(BitVectorType::get(unwrap(ctx), width)); |
50 | } |
51 | |
52 | bool mlirSMTTypeIsABool(MlirType type) { return isa<BoolType>(unwrap(type)); } |
53 | |
54 | MlirType mlirSMTTypeGetBool(MlirContext ctx) { |
55 | return wrap(BoolType::get(unwrap(ctx))); |
56 | } |
57 | |
58 | bool mlirSMTTypeIsAInt(MlirType type) { return isa<IntType>(unwrap(type)); } |
59 | |
60 | MlirType mlirSMTTypeGetInt(MlirContext ctx) { |
61 | return wrap(IntType::get(unwrap(ctx))); |
62 | } |
63 | |
64 | bool mlirSMTTypeIsASMTFunc(MlirType type) { |
65 | return isa<SMTFuncType>(unwrap(type)); |
66 | } |
67 | |
68 | MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, |
69 | const MlirType *domainTypes, |
70 | MlirType rangeType) { |
71 | SmallVector<Type> domainTypesVec; |
72 | domainTypesVec.reserve(N: numberOfDomainTypes); |
73 | |
74 | for (size_t i = 0; i < numberOfDomainTypes; i++) |
75 | domainTypesVec.push_back(Elt: unwrap(c: domainTypes[i])); |
76 | |
77 | return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType))); |
78 | } |
79 | |
80 | bool mlirSMTTypeIsASort(MlirType type) { return isa<SortType>(unwrap(type)); } |
81 | |
82 | MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier, |
83 | size_t numberOfSortParams, |
84 | const MlirType *sortParams) { |
85 | SmallVector<Type> sortParamsVec; |
86 | sortParamsVec.reserve(N: numberOfSortParams); |
87 | |
88 | for (size_t i = 0; i < numberOfSortParams; i++) |
89 | sortParamsVec.push_back(Elt: unwrap(c: sortParams[i])); |
90 | |
91 | return wrap(SortType::get(unwrap(ctx), unwrap(identifier), sortParamsVec)); |
92 | } |
93 | |
94 | //===----------------------------------------------------------------------===// |
95 | // Attribute API. |
96 | //===----------------------------------------------------------------------===// |
97 | |
98 | bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { |
99 | return symbolizeBVCmpPredicate(unwrap(ref: str)).has_value(); |
100 | } |
101 | |
102 | bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { |
103 | return symbolizeIntPredicate(unwrap(ref: str)).has_value(); |
104 | } |
105 | |
106 | bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) { |
107 | return isa<BitVectorAttr, BVCmpPredicateAttr, IntPredicateAttr>(unwrap(attr)); |
108 | } |
109 | |
110 | MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value, |
111 | unsigned width) { |
112 | return wrap(BitVectorAttr::get(unwrap(ctx), value, width)); |
113 | } |
114 | |
115 | MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { |
116 | auto predicate = symbolizeBVCmpPredicate(unwrap(ref: str)); |
117 | assert(predicate.has_value() && "invalid predicate" ); |
118 | |
119 | return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value())); |
120 | } |
121 | |
122 | MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { |
123 | auto predicate = symbolizeIntPredicate(unwrap(ref: str)); |
124 | assert(predicate.has_value() && "invalid predicate" ); |
125 | |
126 | return wrap(IntPredicateAttr::get(unwrap(ctx), predicate.value())); |
127 | } |
128 | |