1 | //===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "mlir-c/BuiltinTypes.h" |
10 | #include "mlir-c/AffineMap.h" |
11 | #include "mlir-c/IR.h" |
12 | #include "mlir-c/Support.h" |
13 | #include "mlir/CAPI/AffineMap.h" |
14 | #include "mlir/CAPI/IR.h" |
15 | #include "mlir/CAPI/Support.h" |
16 | #include "mlir/IR/AffineMap.h" |
17 | #include "mlir/IR/BuiltinTypes.h" |
18 | #include "mlir/IR/Types.h" |
19 | #include "mlir/Support/LogicalResult.h" |
20 | |
21 | #include <algorithm> |
22 | |
23 | using namespace mlir; |
24 | |
25 | //===----------------------------------------------------------------------===// |
26 | // Integer types. |
27 | //===----------------------------------------------------------------------===// |
28 | |
29 | MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } |
30 | |
31 | bool mlirTypeIsAInteger(MlirType type) { |
32 | return llvm::isa<IntegerType>(Val: unwrap(c: type)); |
33 | } |
34 | |
35 | MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { |
36 | return wrap(IntegerType::get(unwrap(ctx), bitwidth)); |
37 | } |
38 | |
39 | MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) { |
40 | return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed)); |
41 | } |
42 | |
43 | MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { |
44 | return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned)); |
45 | } |
46 | |
47 | unsigned mlirIntegerTypeGetWidth(MlirType type) { |
48 | return llvm::cast<IntegerType>(unwrap(c: type)).getWidth(); |
49 | } |
50 | |
51 | bool mlirIntegerTypeIsSignless(MlirType type) { |
52 | return llvm::cast<IntegerType>(unwrap(c: type)).isSignless(); |
53 | } |
54 | |
55 | bool mlirIntegerTypeIsSigned(MlirType type) { |
56 | return llvm::cast<IntegerType>(unwrap(c: type)).isSigned(); |
57 | } |
58 | |
59 | bool mlirIntegerTypeIsUnsigned(MlirType type) { |
60 | return llvm::cast<IntegerType>(unwrap(c: type)).isUnsigned(); |
61 | } |
62 | |
63 | //===----------------------------------------------------------------------===// |
64 | // Index type. |
65 | //===----------------------------------------------------------------------===// |
66 | |
67 | MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } |
68 | |
69 | bool mlirTypeIsAIndex(MlirType type) { |
70 | return llvm::isa<IndexType>(Val: unwrap(c: type)); |
71 | } |
72 | |
73 | MlirType mlirIndexTypeGet(MlirContext ctx) { |
74 | return wrap(IndexType::get(unwrap(ctx))); |
75 | } |
76 | |
77 | //===----------------------------------------------------------------------===// |
78 | // Floating-point types. |
79 | //===----------------------------------------------------------------------===// |
80 | |
81 | bool mlirTypeIsAFloat(MlirType type) { |
82 | return llvm::isa<FloatType>(Val: unwrap(c: type)); |
83 | } |
84 | |
85 | unsigned mlirFloatTypeGetWidth(MlirType type) { |
86 | return llvm::cast<FloatType>(Val: unwrap(c: type)).getWidth(); |
87 | } |
88 | |
89 | MlirTypeID mlirFloat8E5M2TypeGetTypeID() { |
90 | return wrap(Float8E5M2Type::getTypeID()); |
91 | } |
92 | |
93 | bool mlirTypeIsAFloat8E5M2(MlirType type) { |
94 | return unwrap(c: type).isFloat8E5M2(); |
95 | } |
96 | |
97 | MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { |
98 | return wrap(cpp: FloatType::getFloat8E5M2(ctx: unwrap(c: ctx))); |
99 | } |
100 | |
101 | MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { |
102 | return wrap(Float8E4M3FNType::getTypeID()); |
103 | } |
104 | |
105 | bool mlirTypeIsAFloat8E4M3FN(MlirType type) { |
106 | return unwrap(c: type).isFloat8E4M3FN(); |
107 | } |
108 | |
109 | MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { |
110 | return wrap(cpp: FloatType::getFloat8E4M3FN(ctx: unwrap(c: ctx))); |
111 | } |
112 | |
113 | MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { |
114 | return wrap(Float8E5M2FNUZType::getTypeID()); |
115 | } |
116 | |
117 | bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { |
118 | return unwrap(c: type).isFloat8E5M2FNUZ(); |
119 | } |
120 | |
121 | MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { |
122 | return wrap(cpp: FloatType::getFloat8E5M2FNUZ(ctx: unwrap(c: ctx))); |
123 | } |
124 | |
125 | MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { |
126 | return wrap(Float8E4M3FNUZType::getTypeID()); |
127 | } |
128 | |
129 | bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { |
130 | return unwrap(c: type).isFloat8E4M3FNUZ(); |
131 | } |
132 | |
133 | MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { |
134 | return wrap(cpp: FloatType::getFloat8E4M3FNUZ(ctx: unwrap(c: ctx))); |
135 | } |
136 | |
137 | MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { |
138 | return wrap(Float8E4M3B11FNUZType::getTypeID()); |
139 | } |
140 | |
141 | bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { |
142 | return unwrap(c: type).isFloat8E4M3B11FNUZ(); |
143 | } |
144 | |
145 | MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { |
146 | return wrap(cpp: FloatType::getFloat8E4M3B11FNUZ(ctx: unwrap(c: ctx))); |
147 | } |
148 | |
149 | MlirTypeID mlirBFloat16TypeGetTypeID() { |
150 | return wrap(BFloat16Type::getTypeID()); |
151 | } |
152 | |
153 | bool mlirTypeIsABF16(MlirType type) { return unwrap(c: type).isBF16(); } |
154 | |
155 | MlirType mlirBF16TypeGet(MlirContext ctx) { |
156 | return wrap(cpp: FloatType::getBF16(ctx: unwrap(c: ctx))); |
157 | } |
158 | |
159 | MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } |
160 | |
161 | bool mlirTypeIsAF16(MlirType type) { return unwrap(c: type).isF16(); } |
162 | |
163 | MlirType mlirF16TypeGet(MlirContext ctx) { |
164 | return wrap(cpp: FloatType::getF16(ctx: unwrap(c: ctx))); |
165 | } |
166 | |
167 | MlirTypeID mlirFloatTF32TypeGetTypeID() { |
168 | return wrap(FloatTF32Type::getTypeID()); |
169 | } |
170 | |
171 | bool mlirTypeIsATF32(MlirType type) { return unwrap(c: type).isTF32(); } |
172 | |
173 | MlirType mlirTF32TypeGet(MlirContext ctx) { |
174 | return wrap(cpp: FloatType::getTF32(ctx: unwrap(c: ctx))); |
175 | } |
176 | |
177 | MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } |
178 | |
179 | bool mlirTypeIsAF32(MlirType type) { return unwrap(c: type).isF32(); } |
180 | |
181 | MlirType mlirF32TypeGet(MlirContext ctx) { |
182 | return wrap(cpp: FloatType::getF32(ctx: unwrap(c: ctx))); |
183 | } |
184 | |
185 | MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } |
186 | |
187 | bool mlirTypeIsAF64(MlirType type) { return unwrap(c: type).isF64(); } |
188 | |
189 | MlirType mlirF64TypeGet(MlirContext ctx) { |
190 | return wrap(cpp: FloatType::getF64(ctx: unwrap(c: ctx))); |
191 | } |
192 | |
193 | //===----------------------------------------------------------------------===// |
194 | // None type. |
195 | //===----------------------------------------------------------------------===// |
196 | |
197 | MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } |
198 | |
199 | bool mlirTypeIsANone(MlirType type) { |
200 | return llvm::isa<NoneType>(unwrap(type)); |
201 | } |
202 | |
203 | MlirType mlirNoneTypeGet(MlirContext ctx) { |
204 | return wrap(NoneType::get(unwrap(ctx))); |
205 | } |
206 | |
207 | //===----------------------------------------------------------------------===// |
208 | // Complex type. |
209 | //===----------------------------------------------------------------------===// |
210 | |
211 | MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } |
212 | |
213 | bool mlirTypeIsAComplex(MlirType type) { |
214 | return llvm::isa<ComplexType>(unwrap(type)); |
215 | } |
216 | |
217 | MlirType mlirComplexTypeGet(MlirType elementType) { |
218 | return wrap(ComplexType::get(unwrap(elementType))); |
219 | } |
220 | |
221 | MlirType mlirComplexTypeGetElementType(MlirType type) { |
222 | return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType()); |
223 | } |
224 | |
225 | //===----------------------------------------------------------------------===// |
226 | // Shaped type. |
227 | //===----------------------------------------------------------------------===// |
228 | |
229 | bool mlirTypeIsAShaped(MlirType type) { |
230 | return llvm::isa<ShapedType>(unwrap(type)); |
231 | } |
232 | |
233 | MlirType mlirShapedTypeGetElementType(MlirType type) { |
234 | return wrap(llvm::cast<ShapedType>(unwrap(type)).getElementType()); |
235 | } |
236 | |
237 | bool mlirShapedTypeHasRank(MlirType type) { |
238 | return llvm::cast<ShapedType>(unwrap(type)).hasRank(); |
239 | } |
240 | |
241 | int64_t mlirShapedTypeGetRank(MlirType type) { |
242 | return llvm::cast<ShapedType>(unwrap(type)).getRank(); |
243 | } |
244 | |
245 | bool mlirShapedTypeHasStaticShape(MlirType type) { |
246 | return llvm::cast<ShapedType>(unwrap(type)).hasStaticShape(); |
247 | } |
248 | |
249 | bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { |
250 | return llvm::cast<ShapedType>(unwrap(type)) |
251 | .isDynamicDim(static_cast<unsigned>(dim)); |
252 | } |
253 | |
254 | int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { |
255 | return llvm::cast<ShapedType>(unwrap(type)) |
256 | .getDimSize(static_cast<unsigned>(dim)); |
257 | } |
258 | |
259 | int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } |
260 | |
261 | bool mlirShapedTypeIsDynamicSize(int64_t size) { |
262 | return ShapedType::isDynamic(size); |
263 | } |
264 | |
265 | bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { |
266 | return ShapedType::isDynamic(val); |
267 | } |
268 | |
269 | int64_t mlirShapedTypeGetDynamicStrideOrOffset() { |
270 | return ShapedType::kDynamic; |
271 | } |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // Vector type. |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } |
278 | |
279 | bool mlirTypeIsAVector(MlirType type) { |
280 | return llvm::isa<VectorType>(unwrap(type)); |
281 | } |
282 | |
283 | MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, |
284 | MlirType elementType) { |
285 | return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
286 | unwrap(elementType))); |
287 | } |
288 | |
289 | MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, |
290 | const int64_t *shape, MlirType elementType) { |
291 | return wrap(VectorType::getChecked( |
292 | unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
293 | unwrap(elementType))); |
294 | } |
295 | |
296 | MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, |
297 | const bool *scalable, MlirType elementType) { |
298 | return wrap(VectorType::get( |
299 | llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), |
300 | llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); |
301 | } |
302 | |
303 | MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, |
304 | const int64_t *shape, |
305 | const bool *scalable, |
306 | MlirType elementType) { |
307 | return wrap(VectorType::getChecked( |
308 | unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
309 | unwrap(elementType), |
310 | llvm::ArrayRef(scalable, static_cast<size_t>(rank)))); |
311 | } |
312 | |
313 | bool mlirVectorTypeIsScalable(MlirType type) { |
314 | return cast<VectorType>(unwrap(type)).isScalable(); |
315 | } |
316 | |
317 | bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) { |
318 | return cast<VectorType>(unwrap(type)).getScalableDims()[dim]; |
319 | } |
320 | |
321 | //===----------------------------------------------------------------------===// |
322 | // Ranked / Unranked tensor type. |
323 | //===----------------------------------------------------------------------===// |
324 | |
325 | bool mlirTypeIsATensor(MlirType type) { |
326 | return llvm::isa<TensorType>(Val: unwrap(c: type)); |
327 | } |
328 | |
329 | MlirTypeID mlirRankedTensorTypeGetTypeID() { |
330 | return wrap(RankedTensorType::getTypeID()); |
331 | } |
332 | |
333 | bool mlirTypeIsARankedTensor(MlirType type) { |
334 | return llvm::isa<RankedTensorType>(Val: unwrap(c: type)); |
335 | } |
336 | |
337 | MlirTypeID mlirUnrankedTensorTypeGetTypeID() { |
338 | return wrap(UnrankedTensorType::getTypeID()); |
339 | } |
340 | |
341 | bool mlirTypeIsAUnrankedTensor(MlirType type) { |
342 | return llvm::isa<UnrankedTensorType>(unwrap(type)); |
343 | } |
344 | |
345 | MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, |
346 | MlirType elementType, MlirAttribute encoding) { |
347 | return wrap( |
348 | RankedTensorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
349 | unwrap(elementType), unwrap(encoding))); |
350 | } |
351 | |
352 | MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, |
353 | const int64_t *shape, |
354 | MlirType elementType, |
355 | MlirAttribute encoding) { |
356 | return wrap(RankedTensorType::getChecked( |
357 | unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
358 | unwrap(elementType), unwrap(encoding))); |
359 | } |
360 | |
361 | MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { |
362 | return wrap(llvm::cast<RankedTensorType>(unwrap(c: type)).getEncoding()); |
363 | } |
364 | |
365 | MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { |
366 | return wrap(UnrankedTensorType::get(unwrap(elementType))); |
367 | } |
368 | |
369 | MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, |
370 | MlirType elementType) { |
371 | return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); |
372 | } |
373 | |
374 | MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { |
375 | return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType()); |
376 | } |
377 | |
378 | //===----------------------------------------------------------------------===// |
379 | // Ranked / Unranked MemRef type. |
380 | //===----------------------------------------------------------------------===// |
381 | |
382 | MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } |
383 | |
384 | bool mlirTypeIsAMemRef(MlirType type) { |
385 | return llvm::isa<MemRefType>(Val: unwrap(c: type)); |
386 | } |
387 | |
388 | MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, |
389 | const int64_t *shape, MlirAttribute layout, |
390 | MlirAttribute memorySpace) { |
391 | return wrap(MemRefType::get( |
392 | llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType), |
393 | mlirAttributeIsNull(layout) |
394 | ? MemRefLayoutAttrInterface() |
395 | : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)), |
396 | unwrap(memorySpace))); |
397 | } |
398 | |
399 | MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, |
400 | intptr_t rank, const int64_t *shape, |
401 | MlirAttribute layout, |
402 | MlirAttribute memorySpace) { |
403 | return wrap(MemRefType::getChecked( |
404 | unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
405 | unwrap(elementType), |
406 | mlirAttributeIsNull(layout) |
407 | ? MemRefLayoutAttrInterface() |
408 | : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)), |
409 | unwrap(memorySpace))); |
410 | } |
411 | |
412 | MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, |
413 | const int64_t *shape, |
414 | MlirAttribute memorySpace) { |
415 | return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
416 | unwrap(elementType), MemRefLayoutAttrInterface(), |
417 | unwrap(memorySpace))); |
418 | } |
419 | |
420 | MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, |
421 | MlirType elementType, intptr_t rank, |
422 | const int64_t *shape, |
423 | MlirAttribute memorySpace) { |
424 | return wrap(MemRefType::getChecked( |
425 | unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)), |
426 | unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); |
427 | } |
428 | |
429 | MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { |
430 | return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getLayout()); |
431 | } |
432 | |
433 | MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { |
434 | return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getLayout().getAffineMap()); |
435 | } |
436 | |
437 | MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { |
438 | return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getMemorySpace()); |
439 | } |
440 | |
441 | MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, |
442 | int64_t *strides, |
443 | int64_t *offset) { |
444 | MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type)); |
445 | SmallVector<int64_t> strides_; |
446 | if (failed(getStridesAndOffset(memrefType, strides_, *offset))) |
447 | return mlirLogicalResultFailure(); |
448 | |
449 | (void)std::copy(first: strides_.begin(), last: strides_.end(), result: strides); |
450 | return mlirLogicalResultSuccess(); |
451 | } |
452 | |
453 | MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { |
454 | return wrap(UnrankedMemRefType::getTypeID()); |
455 | } |
456 | |
457 | bool mlirTypeIsAUnrankedMemRef(MlirType type) { |
458 | return llvm::isa<UnrankedMemRefType>(unwrap(type)); |
459 | } |
460 | |
461 | MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, |
462 | MlirAttribute memorySpace) { |
463 | return wrap( |
464 | UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); |
465 | } |
466 | |
467 | MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, |
468 | MlirType elementType, |
469 | MlirAttribute memorySpace) { |
470 | return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), |
471 | unwrap(memorySpace))); |
472 | } |
473 | |
474 | MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { |
475 | return wrap(llvm::cast<UnrankedMemRefType>(unwrap(type)).getMemorySpace()); |
476 | } |
477 | |
478 | //===----------------------------------------------------------------------===// |
479 | // Tuple type. |
480 | //===----------------------------------------------------------------------===// |
481 | |
482 | MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } |
483 | |
484 | bool mlirTypeIsATuple(MlirType type) { |
485 | return llvm::isa<TupleType>(unwrap(type)); |
486 | } |
487 | |
488 | MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, |
489 | MlirType const *elements) { |
490 | SmallVector<Type, 4> types; |
491 | ArrayRef<Type> typeRef = unwrapList(size: numElements, first: elements, storage&: types); |
492 | return wrap(TupleType::get(unwrap(ctx), typeRef)); |
493 | } |
494 | |
495 | intptr_t mlirTupleTypeGetNumTypes(MlirType type) { |
496 | return llvm::cast<TupleType>(unwrap(type)).size(); |
497 | } |
498 | |
499 | MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { |
500 | return wrap( |
501 | llvm::cast<TupleType>(unwrap(type)).getType(static_cast<size_t>(pos))); |
502 | } |
503 | |
504 | //===----------------------------------------------------------------------===// |
505 | // Function type. |
506 | //===----------------------------------------------------------------------===// |
507 | |
508 | MlirTypeID mlirFunctionTypeGetTypeID() { |
509 | return wrap(FunctionType::getTypeID()); |
510 | } |
511 | |
512 | bool mlirTypeIsAFunction(MlirType type) { |
513 | return llvm::isa<FunctionType>(Val: unwrap(c: type)); |
514 | } |
515 | |
516 | MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, |
517 | MlirType const *inputs, intptr_t numResults, |
518 | MlirType const *results) { |
519 | SmallVector<Type, 4> inputsList; |
520 | SmallVector<Type, 4> resultsList; |
521 | (void)unwrapList(size: numInputs, first: inputs, storage&: inputsList); |
522 | (void)unwrapList(size: numResults, first: results, storage&: resultsList); |
523 | return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList)); |
524 | } |
525 | |
526 | intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { |
527 | return llvm::cast<FunctionType>(unwrap(c: type)).getNumInputs(); |
528 | } |
529 | |
530 | intptr_t mlirFunctionTypeGetNumResults(MlirType type) { |
531 | return llvm::cast<FunctionType>(unwrap(c: type)).getNumResults(); |
532 | } |
533 | |
534 | MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { |
535 | assert(pos >= 0 && "pos in array must be positive" ); |
536 | return wrap(llvm::cast<FunctionType>(unwrap(c: type)) |
537 | .getInput(static_cast<unsigned>(pos))); |
538 | } |
539 | |
540 | MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { |
541 | assert(pos >= 0 && "pos in array must be positive" ); |
542 | return wrap(llvm::cast<FunctionType>(unwrap(c: type)) |
543 | .getResult(static_cast<unsigned>(pos))); |
544 | } |
545 | |
546 | //===----------------------------------------------------------------------===// |
547 | // Opaque type. |
548 | //===----------------------------------------------------------------------===// |
549 | |
550 | MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } |
551 | |
552 | bool mlirTypeIsAOpaque(MlirType type) { |
553 | return llvm::isa<OpaqueType>(unwrap(type)); |
554 | } |
555 | |
556 | MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, |
557 | MlirStringRef typeData) { |
558 | return wrap( |
559 | OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), |
560 | unwrap(typeData))); |
561 | } |
562 | |
563 | MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { |
564 | return wrap( |
565 | llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref()); |
566 | } |
567 | |
568 | MlirStringRef mlirOpaqueTypeGetData(MlirType type) { |
569 | return wrap(llvm::cast<OpaqueType>(unwrap(type)).getTypeData()); |
570 | } |
571 | |