1//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/BuiltinTypes.h"
10#include "mlir-c/AffineMap.h"
11#include "mlir-c/IR.h"
12#include "mlir-c/Support.h"
13#include "mlir/CAPI/AffineMap.h"
14#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Types.h"
19
20#include <algorithm>
21
22using namespace mlir;
23
24//===----------------------------------------------------------------------===//
25// Integer types.
26//===----------------------------------------------------------------------===//
27
28MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(cpp: IntegerType::getTypeID()); }
29
30bool mlirTypeIsAInteger(MlirType type) {
31 return llvm::isa<IntegerType>(Val: unwrap(c: type));
32}
33
34MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
35 return wrap(cpp: IntegerType::get(context: unwrap(c: ctx), width: bitwidth));
36}
37
38MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
39 return wrap(cpp: IntegerType::get(context: unwrap(c: ctx), width: bitwidth, signedness: IntegerType::Signed));
40}
41
42MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
43 return wrap(cpp: IntegerType::get(context: unwrap(c: ctx), width: bitwidth, signedness: IntegerType::Unsigned));
44}
45
46unsigned mlirIntegerTypeGetWidth(MlirType type) {
47 return llvm::cast<IntegerType>(Val: unwrap(c: type)).getWidth();
48}
49
50bool mlirIntegerTypeIsSignless(MlirType type) {
51 return llvm::cast<IntegerType>(Val: unwrap(c: type)).isSignless();
52}
53
54bool mlirIntegerTypeIsSigned(MlirType type) {
55 return llvm::cast<IntegerType>(Val: unwrap(c: type)).isSigned();
56}
57
58bool mlirIntegerTypeIsUnsigned(MlirType type) {
59 return llvm::cast<IntegerType>(Val: unwrap(c: type)).isUnsigned();
60}
61
62//===----------------------------------------------------------------------===//
63// Index type.
64//===----------------------------------------------------------------------===//
65
66MlirTypeID mlirIndexTypeGetTypeID() { return wrap(cpp: IndexType::getTypeID()); }
67
68bool mlirTypeIsAIndex(MlirType type) {
69 return llvm::isa<IndexType>(Val: unwrap(c: type));
70}
71
72MlirType mlirIndexTypeGet(MlirContext ctx) {
73 return wrap(cpp: IndexType::get(context: unwrap(c: ctx)));
74}
75
76//===----------------------------------------------------------------------===//
77// Floating-point types.
78//===----------------------------------------------------------------------===//
79
80bool mlirTypeIsAFloat(MlirType type) {
81 return llvm::isa<FloatType>(Val: unwrap(c: type));
82}
83
84unsigned mlirFloatTypeGetWidth(MlirType type) {
85 return llvm::cast<FloatType>(Val: unwrap(c: type)).getWidth();
86}
87
88MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() {
89 return wrap(cpp: Float4E2M1FNType::getTypeID());
90}
91
92bool mlirTypeIsAFloat4E2M1FN(MlirType type) {
93 return llvm::isa<Float4E2M1FNType>(Val: unwrap(c: type));
94}
95
96MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) {
97 return wrap(cpp: Float4E2M1FNType::get(ctx: unwrap(c: ctx)));
98}
99
100MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() {
101 return wrap(cpp: Float6E2M3FNType::getTypeID());
102}
103
104bool mlirTypeIsAFloat6E2M3FN(MlirType type) {
105 return llvm::isa<Float6E2M3FNType>(Val: unwrap(c: type));
106}
107
108MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) {
109 return wrap(cpp: Float6E2M3FNType::get(ctx: unwrap(c: ctx)));
110}
111
112MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() {
113 return wrap(cpp: Float6E3M2FNType::getTypeID());
114}
115
116bool mlirTypeIsAFloat6E3M2FN(MlirType type) {
117 return llvm::isa<Float6E3M2FNType>(Val: unwrap(c: type));
118}
119
120MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) {
121 return wrap(cpp: Float6E3M2FNType::get(ctx: unwrap(c: ctx)));
122}
123
124MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
125 return wrap(cpp: Float8E5M2Type::getTypeID());
126}
127
128bool mlirTypeIsAFloat8E5M2(MlirType type) {
129 return llvm::isa<Float8E5M2Type>(Val: unwrap(c: type));
130}
131
132MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
133 return wrap(cpp: Float8E5M2Type::get(ctx: unwrap(c: ctx)));
134}
135
136MlirTypeID mlirFloat8E4M3TypeGetTypeID() {
137 return wrap(cpp: Float8E4M3Type::getTypeID());
138}
139
140bool mlirTypeIsAFloat8E4M3(MlirType type) {
141 return llvm::isa<Float8E4M3Type>(Val: unwrap(c: type));
142}
143
144MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) {
145 return wrap(cpp: Float8E4M3Type::get(ctx: unwrap(c: ctx)));
146}
147
148MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
149 return wrap(cpp: Float8E4M3FNType::getTypeID());
150}
151
152bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
153 return llvm::isa<Float8E4M3FNType>(Val: unwrap(c: type));
154}
155
156MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
157 return wrap(cpp: Float8E4M3FNType::get(ctx: unwrap(c: ctx)));
158}
159
160MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
161 return wrap(cpp: Float8E5M2FNUZType::getTypeID());
162}
163
164bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
165 return llvm::isa<Float8E5M2FNUZType>(Val: unwrap(c: type));
166}
167
168MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
169 return wrap(cpp: Float8E5M2FNUZType::get(ctx: unwrap(c: ctx)));
170}
171
172MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
173 return wrap(cpp: Float8E4M3FNUZType::getTypeID());
174}
175
176bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
177 return llvm::isa<Float8E4M3FNUZType>(Val: unwrap(c: type));
178}
179
180MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
181 return wrap(cpp: Float8E4M3FNUZType::get(ctx: unwrap(c: ctx)));
182}
183
184MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
185 return wrap(cpp: Float8E4M3B11FNUZType::getTypeID());
186}
187
188bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
189 return llvm::isa<Float8E4M3B11FNUZType>(Val: unwrap(c: type));
190}
191
192MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
193 return wrap(cpp: Float8E4M3B11FNUZType::get(ctx: unwrap(c: ctx)));
194}
195
196MlirTypeID mlirFloat8E3M4TypeGetTypeID() {
197 return wrap(cpp: Float8E3M4Type::getTypeID());
198}
199
200bool mlirTypeIsAFloat8E3M4(MlirType type) {
201 return llvm::isa<Float8E3M4Type>(Val: unwrap(c: type));
202}
203
204MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) {
205 return wrap(cpp: Float8E3M4Type::get(ctx: unwrap(c: ctx)));
206}
207
208MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() {
209 return wrap(cpp: Float8E8M0FNUType::getTypeID());
210}
211
212bool mlirTypeIsAFloat8E8M0FNU(MlirType type) {
213 return llvm::isa<Float8E8M0FNUType>(Val: unwrap(c: type));
214}
215
216MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) {
217 return wrap(cpp: Float8E8M0FNUType::get(ctx: unwrap(c: ctx)));
218}
219
220MlirTypeID mlirBFloat16TypeGetTypeID() {
221 return wrap(cpp: BFloat16Type::getTypeID());
222}
223
224bool mlirTypeIsABF16(MlirType type) {
225 return llvm::isa<BFloat16Type>(Val: unwrap(c: type));
226}
227
228MlirType mlirBF16TypeGet(MlirContext ctx) {
229 return wrap(cpp: BFloat16Type::get(context: unwrap(c: ctx)));
230}
231
232MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(cpp: Float16Type::getTypeID()); }
233
234bool mlirTypeIsAF16(MlirType type) {
235 return llvm::isa<Float16Type>(Val: unwrap(c: type));
236}
237
238MlirType mlirF16TypeGet(MlirContext ctx) {
239 return wrap(cpp: Float16Type::get(context: unwrap(c: ctx)));
240}
241
242MlirTypeID mlirFloatTF32TypeGetTypeID() {
243 return wrap(cpp: FloatTF32Type::getTypeID());
244}
245
246bool mlirTypeIsATF32(MlirType type) {
247 return llvm::isa<FloatTF32Type>(Val: unwrap(c: type));
248}
249
250MlirType mlirTF32TypeGet(MlirContext ctx) {
251 return wrap(cpp: FloatTF32Type::get(context: unwrap(c: ctx)));
252}
253
254MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(cpp: Float32Type::getTypeID()); }
255
256bool mlirTypeIsAF32(MlirType type) {
257 return llvm::isa<Float32Type>(Val: unwrap(c: type));
258}
259
260MlirType mlirF32TypeGet(MlirContext ctx) {
261 return wrap(cpp: Float32Type::get(context: unwrap(c: ctx)));
262}
263
264MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(cpp: Float64Type::getTypeID()); }
265
266bool mlirTypeIsAF64(MlirType type) {
267 return llvm::isa<Float64Type>(Val: unwrap(c: type));
268}
269
270MlirType mlirF64TypeGet(MlirContext ctx) {
271 return wrap(cpp: Float64Type::get(context: unwrap(c: ctx)));
272}
273
274//===----------------------------------------------------------------------===//
275// None type.
276//===----------------------------------------------------------------------===//
277
278MlirTypeID mlirNoneTypeGetTypeID() { return wrap(cpp: NoneType::getTypeID()); }
279
280bool mlirTypeIsANone(MlirType type) {
281 return llvm::isa<NoneType>(Val: unwrap(c: type));
282}
283
284MlirType mlirNoneTypeGet(MlirContext ctx) {
285 return wrap(cpp: NoneType::get(context: unwrap(c: ctx)));
286}
287
288//===----------------------------------------------------------------------===//
289// Complex type.
290//===----------------------------------------------------------------------===//
291
292MlirTypeID mlirComplexTypeGetTypeID() { return wrap(cpp: ComplexType::getTypeID()); }
293
294bool mlirTypeIsAComplex(MlirType type) {
295 return llvm::isa<ComplexType>(Val: unwrap(c: type));
296}
297
298MlirType mlirComplexTypeGet(MlirType elementType) {
299 return wrap(cpp: ComplexType::get(elementType: unwrap(c: elementType)));
300}
301
302MlirType mlirComplexTypeGetElementType(MlirType type) {
303 return wrap(cpp: llvm::cast<ComplexType>(Val: unwrap(c: type)).getElementType());
304}
305
306//===----------------------------------------------------------------------===//
307// Shaped type.
308//===----------------------------------------------------------------------===//
309
310bool mlirTypeIsAShaped(MlirType type) {
311 return llvm::isa<ShapedType>(Val: unwrap(c: type));
312}
313
314MlirType mlirShapedTypeGetElementType(MlirType type) {
315 return wrap(cpp: llvm::cast<ShapedType>(Val: unwrap(c: type)).getElementType());
316}
317
318bool mlirShapedTypeHasRank(MlirType type) {
319 return llvm::cast<ShapedType>(Val: unwrap(c: type)).hasRank();
320}
321
322int64_t mlirShapedTypeGetRank(MlirType type) {
323 return llvm::cast<ShapedType>(Val: unwrap(c: type)).getRank();
324}
325
326bool mlirShapedTypeHasStaticShape(MlirType type) {
327 return llvm::cast<ShapedType>(Val: unwrap(c: type)).hasStaticShape();
328}
329
330bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
331 return llvm::cast<ShapedType>(Val: unwrap(c: type))
332 .isDynamicDim(idx: static_cast<unsigned>(dim));
333}
334
335bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) {
336 return llvm::cast<ShapedType>(Val: unwrap(c: type))
337 .isStaticDim(idx: static_cast<unsigned>(dim));
338}
339
340int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
341 return llvm::cast<ShapedType>(Val: unwrap(c: type))
342 .getDimSize(idx: static_cast<unsigned>(dim));
343}
344
345int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; }
346
347bool mlirShapedTypeIsDynamicSize(int64_t size) {
348 return ShapedType::isDynamic(dValue: size);
349}
350
351bool mlirShapedTypeIsStaticSize(int64_t size) {
352 return ShapedType::isStatic(dValue: size);
353}
354
355bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
356 return ShapedType::isDynamic(dValue: val);
357}
358
359bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) {
360 return ShapedType::isStatic(dValue: val);
361}
362
363int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
364 return ShapedType::kDynamic;
365}
366
367//===----------------------------------------------------------------------===//
368// Vector type.
369//===----------------------------------------------------------------------===//
370
371MlirTypeID mlirVectorTypeGetTypeID() { return wrap(cpp: VectorType::getTypeID()); }
372
373bool mlirTypeIsAVector(MlirType type) {
374 return llvm::isa<VectorType>(Val: unwrap(c: type));
375}
376
377MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
378 MlirType elementType) {
379 return wrap(cpp: VectorType::get(shape: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
380 elementType: unwrap(c: elementType)));
381}
382
383MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
384 const int64_t *shape, MlirType elementType) {
385 return wrap(cpp: VectorType::getChecked(
386 loc: unwrap(c: loc), args: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
387 args: unwrap(c: elementType)));
388}
389
390MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
391 const bool *scalable, MlirType elementType) {
392 return wrap(cpp: VectorType::get(
393 shape: llvm::ArrayRef(shape, static_cast<size_t>(rank)), elementType: unwrap(c: elementType),
394 scalableDims: llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
395}
396
397MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
398 const int64_t *shape,
399 const bool *scalable,
400 MlirType elementType) {
401 return wrap(cpp: VectorType::getChecked(
402 loc: unwrap(c: loc), args: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
403 args: unwrap(c: elementType),
404 args: llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
405}
406
407bool mlirVectorTypeIsScalable(MlirType type) {
408 return cast<VectorType>(Val: unwrap(c: type)).isScalable();
409}
410
411bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
412 return cast<VectorType>(Val: unwrap(c: type)).getScalableDims()[dim];
413}
414
415//===----------------------------------------------------------------------===//
416// Ranked / Unranked tensor type.
417//===----------------------------------------------------------------------===//
418
419bool mlirTypeIsATensor(MlirType type) {
420 return llvm::isa<TensorType>(Val: unwrap(c: type));
421}
422
423MlirTypeID mlirRankedTensorTypeGetTypeID() {
424 return wrap(cpp: RankedTensorType::getTypeID());
425}
426
427bool mlirTypeIsARankedTensor(MlirType type) {
428 return llvm::isa<RankedTensorType>(Val: unwrap(c: type));
429}
430
431MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
432 return wrap(cpp: UnrankedTensorType::getTypeID());
433}
434
435bool mlirTypeIsAUnrankedTensor(MlirType type) {
436 return llvm::isa<UnrankedTensorType>(Val: unwrap(c: type));
437}
438
439MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
440 MlirType elementType, MlirAttribute encoding) {
441 return wrap(
442 cpp: RankedTensorType::get(shape: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
443 elementType: unwrap(c: elementType), encoding: unwrap(c: encoding)));
444}
445
446MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
447 const int64_t *shape,
448 MlirType elementType,
449 MlirAttribute encoding) {
450 return wrap(cpp: RankedTensorType::getChecked(
451 loc: unwrap(c: loc), args: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
452 args: unwrap(c: elementType), args: unwrap(c: encoding)));
453}
454
455MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
456 return wrap(cpp: llvm::cast<RankedTensorType>(Val: unwrap(c: type)).getEncoding());
457}
458
459MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
460 return wrap(cpp: UnrankedTensorType::get(elementType: unwrap(c: elementType)));
461}
462
463MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
464 MlirType elementType) {
465 return wrap(cpp: UnrankedTensorType::getChecked(loc: unwrap(c: loc), args: unwrap(c: elementType)));
466}
467
468MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
469 return wrap(cpp: llvm::cast<UnrankedTensorType>(Val: unwrap(c: type)).getElementType());
470}
471
472//===----------------------------------------------------------------------===//
473// Ranked / Unranked MemRef type.
474//===----------------------------------------------------------------------===//
475
476MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(cpp: MemRefType::getTypeID()); }
477
478bool mlirTypeIsAMemRef(MlirType type) {
479 return llvm::isa<MemRefType>(Val: unwrap(c: type));
480}
481
482MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
483 const int64_t *shape, MlirAttribute layout,
484 MlirAttribute memorySpace) {
485 return wrap(cpp: MemRefType::get(
486 shape: llvm::ArrayRef(shape, static_cast<size_t>(rank)), elementType: unwrap(c: elementType),
487 layout: mlirAttributeIsNull(attr: layout)
488 ? MemRefLayoutAttrInterface()
489 : llvm::cast<MemRefLayoutAttrInterface>(Val: unwrap(c: layout)),
490 memorySpace: unwrap(c: memorySpace)));
491}
492
493MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
494 intptr_t rank, const int64_t *shape,
495 MlirAttribute layout,
496 MlirAttribute memorySpace) {
497 return wrap(cpp: MemRefType::getChecked(
498 loc: unwrap(c: loc), args: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
499 args: unwrap(c: elementType),
500 args: mlirAttributeIsNull(attr: layout)
501 ? MemRefLayoutAttrInterface()
502 : llvm::cast<MemRefLayoutAttrInterface>(Val: unwrap(c: layout)),
503 args: unwrap(c: memorySpace)));
504}
505
506MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
507 const int64_t *shape,
508 MlirAttribute memorySpace) {
509 return wrap(cpp: MemRefType::get(shape: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
510 elementType: unwrap(c: elementType), layout: MemRefLayoutAttrInterface(),
511 memorySpace: unwrap(c: memorySpace)));
512}
513
514MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
515 MlirType elementType, intptr_t rank,
516 const int64_t *shape,
517 MlirAttribute memorySpace) {
518 return wrap(cpp: MemRefType::getChecked(
519 loc: unwrap(c: loc), args: llvm::ArrayRef(shape, static_cast<size_t>(rank)),
520 args: unwrap(c: elementType), args: MemRefLayoutAttrInterface(), args: unwrap(c: memorySpace)));
521}
522
523MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
524 return wrap(cpp: llvm::cast<MemRefType>(Val: unwrap(c: type)).getLayout());
525}
526
527MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
528 return wrap(cpp: llvm::cast<MemRefType>(Val: unwrap(c: type)).getLayout().getAffineMap());
529}
530
531MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
532 return wrap(cpp: llvm::cast<MemRefType>(Val: unwrap(c: type)).getMemorySpace());
533}
534
535MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
536 int64_t *strides,
537 int64_t *offset) {
538 MemRefType memrefType = llvm::cast<MemRefType>(Val: unwrap(c: type));
539 SmallVector<int64_t> strides_;
540 if (failed(Result: memrefType.getStridesAndOffset(strides&: strides_, offset&: *offset)))
541 return mlirLogicalResultFailure();
542
543 (void)std::copy(first: strides_.begin(), last: strides_.end(), result: strides);
544 return mlirLogicalResultSuccess();
545}
546
547MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
548 return wrap(cpp: UnrankedMemRefType::getTypeID());
549}
550
551bool mlirTypeIsAUnrankedMemRef(MlirType type) {
552 return llvm::isa<UnrankedMemRefType>(Val: unwrap(c: type));
553}
554
555MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
556 MlirAttribute memorySpace) {
557 return wrap(
558 cpp: UnrankedMemRefType::get(elementType: unwrap(c: elementType), memorySpace: unwrap(c: memorySpace)));
559}
560
561MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
562 MlirType elementType,
563 MlirAttribute memorySpace) {
564 return wrap(cpp: UnrankedMemRefType::getChecked(loc: unwrap(c: loc), args: unwrap(c: elementType),
565 args: unwrap(c: memorySpace)));
566}
567
568MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
569 return wrap(cpp: llvm::cast<UnrankedMemRefType>(Val: unwrap(c: type)).getMemorySpace());
570}
571
572//===----------------------------------------------------------------------===//
573// Tuple type.
574//===----------------------------------------------------------------------===//
575
576MlirTypeID mlirTupleTypeGetTypeID() { return wrap(cpp: TupleType::getTypeID()); }
577
578bool mlirTypeIsATuple(MlirType type) {
579 return llvm::isa<TupleType>(Val: unwrap(c: type));
580}
581
582MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
583 MlirType const *elements) {
584 SmallVector<Type, 4> types;
585 ArrayRef<Type> typeRef = unwrapList(size: numElements, first: elements, storage&: types);
586 return wrap(cpp: TupleType::get(context: unwrap(c: ctx), elementTypes: typeRef));
587}
588
589intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
590 return llvm::cast<TupleType>(Val: unwrap(c: type)).size();
591}
592
593MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
594 return wrap(
595 cpp: llvm::cast<TupleType>(Val: unwrap(c: type)).getType(index: static_cast<size_t>(pos)));
596}
597
598//===----------------------------------------------------------------------===//
599// Function type.
600//===----------------------------------------------------------------------===//
601
602MlirTypeID mlirFunctionTypeGetTypeID() {
603 return wrap(cpp: FunctionType::getTypeID());
604}
605
606bool mlirTypeIsAFunction(MlirType type) {
607 return llvm::isa<FunctionType>(Val: unwrap(c: type));
608}
609
610MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
611 MlirType const *inputs, intptr_t numResults,
612 MlirType const *results) {
613 SmallVector<Type, 4> inputsList;
614 SmallVector<Type, 4> resultsList;
615 (void)unwrapList(size: numInputs, first: inputs, storage&: inputsList);
616 (void)unwrapList(size: numResults, first: results, storage&: resultsList);
617 return wrap(cpp: FunctionType::get(context: unwrap(c: ctx), inputs: inputsList, results: resultsList));
618}
619
620intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
621 return llvm::cast<FunctionType>(Val: unwrap(c: type)).getNumInputs();
622}
623
624intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
625 return llvm::cast<FunctionType>(Val: unwrap(c: type)).getNumResults();
626}
627
628MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
629 assert(pos >= 0 && "pos in array must be positive");
630 return wrap(cpp: llvm::cast<FunctionType>(Val: unwrap(c: type))
631 .getInput(i: static_cast<unsigned>(pos)));
632}
633
634MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
635 assert(pos >= 0 && "pos in array must be positive");
636 return wrap(cpp: llvm::cast<FunctionType>(Val: unwrap(c: type))
637 .getResult(i: static_cast<unsigned>(pos)));
638}
639
640//===----------------------------------------------------------------------===//
641// Opaque type.
642//===----------------------------------------------------------------------===//
643
644MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(cpp: OpaqueType::getTypeID()); }
645
646bool mlirTypeIsAOpaque(MlirType type) {
647 return llvm::isa<OpaqueType>(Val: unwrap(c: type));
648}
649
650MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
651 MlirStringRef typeData) {
652 return wrap(
653 cpp: OpaqueType::get(dialectNamespace: StringAttr::get(context: unwrap(c: ctx), bytes: unwrap(ref: dialectNamespace)),
654 typeData: unwrap(ref: typeData)));
655}
656
657MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
658 return wrap(
659 ref: llvm::cast<OpaqueType>(Val: unwrap(c: type)).getDialectNamespace().strref());
660}
661
662MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
663 return wrap(ref: llvm::cast<OpaqueType>(Val: unwrap(c: type)).getTypeData());
664}
665

// Source: mlir/lib/CAPI/IR/BuiltinTypes.cpp