1//===- BuiltinTypes.cpp - C Interface to MLIR Builtin Types ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir-c/BuiltinTypes.h"
10#include "mlir-c/AffineMap.h"
11#include "mlir-c/IR.h"
12#include "mlir-c/Support.h"
13#include "mlir/CAPI/AffineMap.h"
14#include "mlir/CAPI/IR.h"
15#include "mlir/CAPI/Support.h"
16#include "mlir/IR/AffineMap.h"
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/Types.h"
19#include "mlir/Support/LogicalResult.h"
20
21#include <algorithm>
22
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// Integer types.
27//===----------------------------------------------------------------------===//
28
29MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); }
30
31bool mlirTypeIsAInteger(MlirType type) {
32 return llvm::isa<IntegerType>(Val: unwrap(c: type));
33}
34
35MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) {
36 return wrap(IntegerType::get(unwrap(ctx), bitwidth));
37}
38
39MlirType mlirIntegerTypeSignedGet(MlirContext ctx, unsigned bitwidth) {
40 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Signed));
41}
42
43MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) {
44 return wrap(IntegerType::get(unwrap(ctx), bitwidth, IntegerType::Unsigned));
45}
46
47unsigned mlirIntegerTypeGetWidth(MlirType type) {
48 return llvm::cast<IntegerType>(unwrap(c: type)).getWidth();
49}
50
51bool mlirIntegerTypeIsSignless(MlirType type) {
52 return llvm::cast<IntegerType>(unwrap(c: type)).isSignless();
53}
54
55bool mlirIntegerTypeIsSigned(MlirType type) {
56 return llvm::cast<IntegerType>(unwrap(c: type)).isSigned();
57}
58
59bool mlirIntegerTypeIsUnsigned(MlirType type) {
60 return llvm::cast<IntegerType>(unwrap(c: type)).isUnsigned();
61}
62
63//===----------------------------------------------------------------------===//
64// Index type.
65//===----------------------------------------------------------------------===//
66
67MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); }
68
69bool mlirTypeIsAIndex(MlirType type) {
70 return llvm::isa<IndexType>(Val: unwrap(c: type));
71}
72
73MlirType mlirIndexTypeGet(MlirContext ctx) {
74 return wrap(IndexType::get(unwrap(ctx)));
75}
76
77//===----------------------------------------------------------------------===//
78// Floating-point types.
79//===----------------------------------------------------------------------===//
80
81bool mlirTypeIsAFloat(MlirType type) {
82 return llvm::isa<FloatType>(Val: unwrap(c: type));
83}
84
85unsigned mlirFloatTypeGetWidth(MlirType type) {
86 return llvm::cast<FloatType>(Val: unwrap(c: type)).getWidth();
87}
88
89MlirTypeID mlirFloat8E5M2TypeGetTypeID() {
90 return wrap(Float8E5M2Type::getTypeID());
91}
92
93bool mlirTypeIsAFloat8E5M2(MlirType type) {
94 return unwrap(c: type).isFloat8E5M2();
95}
96
97MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) {
98 return wrap(cpp: FloatType::getFloat8E5M2(ctx: unwrap(c: ctx)));
99}
100
101MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() {
102 return wrap(Float8E4M3FNType::getTypeID());
103}
104
105bool mlirTypeIsAFloat8E4M3FN(MlirType type) {
106 return unwrap(c: type).isFloat8E4M3FN();
107}
108
109MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) {
110 return wrap(cpp: FloatType::getFloat8E4M3FN(ctx: unwrap(c: ctx)));
111}
112
113MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() {
114 return wrap(Float8E5M2FNUZType::getTypeID());
115}
116
117bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) {
118 return unwrap(c: type).isFloat8E5M2FNUZ();
119}
120
121MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) {
122 return wrap(cpp: FloatType::getFloat8E5M2FNUZ(ctx: unwrap(c: ctx)));
123}
124
125MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() {
126 return wrap(Float8E4M3FNUZType::getTypeID());
127}
128
129bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) {
130 return unwrap(c: type).isFloat8E4M3FNUZ();
131}
132
133MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) {
134 return wrap(cpp: FloatType::getFloat8E4M3FNUZ(ctx: unwrap(c: ctx)));
135}
136
137MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() {
138 return wrap(Float8E4M3B11FNUZType::getTypeID());
139}
140
141bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) {
142 return unwrap(c: type).isFloat8E4M3B11FNUZ();
143}
144
145MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) {
146 return wrap(cpp: FloatType::getFloat8E4M3B11FNUZ(ctx: unwrap(c: ctx)));
147}
148
149MlirTypeID mlirBFloat16TypeGetTypeID() {
150 return wrap(BFloat16Type::getTypeID());
151}
152
153bool mlirTypeIsABF16(MlirType type) { return unwrap(c: type).isBF16(); }
154
155MlirType mlirBF16TypeGet(MlirContext ctx) {
156 return wrap(cpp: FloatType::getBF16(ctx: unwrap(c: ctx)));
157}
158
159MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); }
160
161bool mlirTypeIsAF16(MlirType type) { return unwrap(c: type).isF16(); }
162
163MlirType mlirF16TypeGet(MlirContext ctx) {
164 return wrap(cpp: FloatType::getF16(ctx: unwrap(c: ctx)));
165}
166
167MlirTypeID mlirFloatTF32TypeGetTypeID() {
168 return wrap(FloatTF32Type::getTypeID());
169}
170
171bool mlirTypeIsATF32(MlirType type) { return unwrap(c: type).isTF32(); }
172
173MlirType mlirTF32TypeGet(MlirContext ctx) {
174 return wrap(cpp: FloatType::getTF32(ctx: unwrap(c: ctx)));
175}
176
177MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); }
178
179bool mlirTypeIsAF32(MlirType type) { return unwrap(c: type).isF32(); }
180
181MlirType mlirF32TypeGet(MlirContext ctx) {
182 return wrap(cpp: FloatType::getF32(ctx: unwrap(c: ctx)));
183}
184
185MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); }
186
187bool mlirTypeIsAF64(MlirType type) { return unwrap(c: type).isF64(); }
188
189MlirType mlirF64TypeGet(MlirContext ctx) {
190 return wrap(cpp: FloatType::getF64(ctx: unwrap(c: ctx)));
191}
192
193//===----------------------------------------------------------------------===//
194// None type.
195//===----------------------------------------------------------------------===//
196
197MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); }
198
199bool mlirTypeIsANone(MlirType type) {
200 return llvm::isa<NoneType>(unwrap(type));
201}
202
203MlirType mlirNoneTypeGet(MlirContext ctx) {
204 return wrap(NoneType::get(unwrap(ctx)));
205}
206
207//===----------------------------------------------------------------------===//
208// Complex type.
209//===----------------------------------------------------------------------===//
210
211MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); }
212
213bool mlirTypeIsAComplex(MlirType type) {
214 return llvm::isa<ComplexType>(unwrap(type));
215}
216
217MlirType mlirComplexTypeGet(MlirType elementType) {
218 return wrap(ComplexType::get(unwrap(elementType)));
219}
220
221MlirType mlirComplexTypeGetElementType(MlirType type) {
222 return wrap(llvm::cast<ComplexType>(unwrap(type)).getElementType());
223}
224
225//===----------------------------------------------------------------------===//
226// Shaped type.
227//===----------------------------------------------------------------------===//
228
229bool mlirTypeIsAShaped(MlirType type) {
230 return llvm::isa<ShapedType>(unwrap(type));
231}
232
233MlirType mlirShapedTypeGetElementType(MlirType type) {
234 return wrap(llvm::cast<ShapedType>(unwrap(type)).getElementType());
235}
236
237bool mlirShapedTypeHasRank(MlirType type) {
238 return llvm::cast<ShapedType>(unwrap(type)).hasRank();
239}
240
241int64_t mlirShapedTypeGetRank(MlirType type) {
242 return llvm::cast<ShapedType>(unwrap(type)).getRank();
243}
244
245bool mlirShapedTypeHasStaticShape(MlirType type) {
246 return llvm::cast<ShapedType>(unwrap(type)).hasStaticShape();
247}
248
249bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) {
250 return llvm::cast<ShapedType>(unwrap(type))
251 .isDynamicDim(static_cast<unsigned>(dim));
252}
253
254int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
255 return llvm::cast<ShapedType>(unwrap(type))
256 .getDimSize(static_cast<unsigned>(dim));
257}
258
259int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; }
260
261bool mlirShapedTypeIsDynamicSize(int64_t size) {
262 return ShapedType::isDynamic(size);
263}
264
265bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
266 return ShapedType::isDynamic(val);
267}
268
269int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
270 return ShapedType::kDynamic;
271}
272
273//===----------------------------------------------------------------------===//
274// Vector type.
275//===----------------------------------------------------------------------===//
276
277MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); }
278
279bool mlirTypeIsAVector(MlirType type) {
280 return llvm::isa<VectorType>(unwrap(type));
281}
282
283MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape,
284 MlirType elementType) {
285 return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
286 unwrap(elementType)));
287}
288
289MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank,
290 const int64_t *shape, MlirType elementType) {
291 return wrap(VectorType::getChecked(
292 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
293 unwrap(elementType)));
294}
295
296MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape,
297 const bool *scalable, MlirType elementType) {
298 return wrap(VectorType::get(
299 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
300 llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
301}
302
303MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank,
304 const int64_t *shape,
305 const bool *scalable,
306 MlirType elementType) {
307 return wrap(VectorType::getChecked(
308 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
309 unwrap(elementType),
310 llvm::ArrayRef(scalable, static_cast<size_t>(rank))));
311}
312
313bool mlirVectorTypeIsScalable(MlirType type) {
314 return cast<VectorType>(unwrap(type)).isScalable();
315}
316
317bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) {
318 return cast<VectorType>(unwrap(type)).getScalableDims()[dim];
319}
320
321//===----------------------------------------------------------------------===//
322// Ranked / Unranked tensor type.
323//===----------------------------------------------------------------------===//
324
325bool mlirTypeIsATensor(MlirType type) {
326 return llvm::isa<TensorType>(Val: unwrap(c: type));
327}
328
329MlirTypeID mlirRankedTensorTypeGetTypeID() {
330 return wrap(RankedTensorType::getTypeID());
331}
332
333bool mlirTypeIsARankedTensor(MlirType type) {
334 return llvm::isa<RankedTensorType>(Val: unwrap(c: type));
335}
336
337MlirTypeID mlirUnrankedTensorTypeGetTypeID() {
338 return wrap(UnrankedTensorType::getTypeID());
339}
340
341bool mlirTypeIsAUnrankedTensor(MlirType type) {
342 return llvm::isa<UnrankedTensorType>(unwrap(type));
343}
344
345MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
346 MlirType elementType, MlirAttribute encoding) {
347 return wrap(
348 RankedTensorType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
349 unwrap(elementType), unwrap(encoding)));
350}
351
352MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
353 const int64_t *shape,
354 MlirType elementType,
355 MlirAttribute encoding) {
356 return wrap(RankedTensorType::getChecked(
357 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
358 unwrap(elementType), unwrap(encoding)));
359}
360
361MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) {
362 return wrap(llvm::cast<RankedTensorType>(unwrap(c: type)).getEncoding());
363}
364
365MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
366 return wrap(UnrankedTensorType::get(unwrap(elementType)));
367}
368
369MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc,
370 MlirType elementType) {
371 return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType)));
372}
373
374MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) {
375 return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType());
376}
377
378//===----------------------------------------------------------------------===//
379// Ranked / Unranked MemRef type.
380//===----------------------------------------------------------------------===//
381
382MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); }
383
384bool mlirTypeIsAMemRef(MlirType type) {
385 return llvm::isa<MemRefType>(Val: unwrap(c: type));
386}
387
388MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank,
389 const int64_t *shape, MlirAttribute layout,
390 MlirAttribute memorySpace) {
391 return wrap(MemRefType::get(
392 llvm::ArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
393 mlirAttributeIsNull(layout)
394 ? MemRefLayoutAttrInterface()
395 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
396 unwrap(memorySpace)));
397}
398
399MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType,
400 intptr_t rank, const int64_t *shape,
401 MlirAttribute layout,
402 MlirAttribute memorySpace) {
403 return wrap(MemRefType::getChecked(
404 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
405 unwrap(elementType),
406 mlirAttributeIsNull(layout)
407 ? MemRefLayoutAttrInterface()
408 : llvm::cast<MemRefLayoutAttrInterface>(unwrap(layout)),
409 unwrap(memorySpace)));
410}
411
412MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank,
413 const int64_t *shape,
414 MlirAttribute memorySpace) {
415 return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast<size_t>(rank)),
416 unwrap(elementType), MemRefLayoutAttrInterface(),
417 unwrap(memorySpace)));
418}
419
420MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc,
421 MlirType elementType, intptr_t rank,
422 const int64_t *shape,
423 MlirAttribute memorySpace) {
424 return wrap(MemRefType::getChecked(
425 unwrap(loc), llvm::ArrayRef(shape, static_cast<size_t>(rank)),
426 unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace)));
427}
428
429MlirAttribute mlirMemRefTypeGetLayout(MlirType type) {
430 return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getLayout());
431}
432
433MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) {
434 return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getLayout().getAffineMap());
435}
436
437MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) {
438 return wrap(llvm::cast<MemRefType>(unwrap(c: type)).getMemorySpace());
439}
440
441MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type,
442 int64_t *strides,
443 int64_t *offset) {
444 MemRefType memrefType = llvm::cast<MemRefType>(unwrap(type));
445 SmallVector<int64_t> strides_;
446 if (failed(getStridesAndOffset(memrefType, strides_, *offset)))
447 return mlirLogicalResultFailure();
448
449 (void)std::copy(first: strides_.begin(), last: strides_.end(), result: strides);
450 return mlirLogicalResultSuccess();
451}
452
453MlirTypeID mlirUnrankedMemRefTypeGetTypeID() {
454 return wrap(UnrankedMemRefType::getTypeID());
455}
456
457bool mlirTypeIsAUnrankedMemRef(MlirType type) {
458 return llvm::isa<UnrankedMemRefType>(unwrap(type));
459}
460
461MlirType mlirUnrankedMemRefTypeGet(MlirType elementType,
462 MlirAttribute memorySpace) {
463 return wrap(
464 UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace)));
465}
466
467MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc,
468 MlirType elementType,
469 MlirAttribute memorySpace) {
470 return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType),
471 unwrap(memorySpace)));
472}
473
474MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) {
475 return wrap(llvm::cast<UnrankedMemRefType>(unwrap(type)).getMemorySpace());
476}
477
478//===----------------------------------------------------------------------===//
479// Tuple type.
480//===----------------------------------------------------------------------===//
481
482MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); }
483
484bool mlirTypeIsATuple(MlirType type) {
485 return llvm::isa<TupleType>(unwrap(type));
486}
487
488MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements,
489 MlirType const *elements) {
490 SmallVector<Type, 4> types;
491 ArrayRef<Type> typeRef = unwrapList(size: numElements, first: elements, storage&: types);
492 return wrap(TupleType::get(unwrap(ctx), typeRef));
493}
494
495intptr_t mlirTupleTypeGetNumTypes(MlirType type) {
496 return llvm::cast<TupleType>(unwrap(type)).size();
497}
498
499MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) {
500 return wrap(
501 llvm::cast<TupleType>(unwrap(type)).getType(static_cast<size_t>(pos)));
502}
503
504//===----------------------------------------------------------------------===//
505// Function type.
506//===----------------------------------------------------------------------===//
507
508MlirTypeID mlirFunctionTypeGetTypeID() {
509 return wrap(FunctionType::getTypeID());
510}
511
512bool mlirTypeIsAFunction(MlirType type) {
513 return llvm::isa<FunctionType>(Val: unwrap(c: type));
514}
515
516MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs,
517 MlirType const *inputs, intptr_t numResults,
518 MlirType const *results) {
519 SmallVector<Type, 4> inputsList;
520 SmallVector<Type, 4> resultsList;
521 (void)unwrapList(size: numInputs, first: inputs, storage&: inputsList);
522 (void)unwrapList(size: numResults, first: results, storage&: resultsList);
523 return wrap(FunctionType::get(unwrap(ctx), inputsList, resultsList));
524}
525
526intptr_t mlirFunctionTypeGetNumInputs(MlirType type) {
527 return llvm::cast<FunctionType>(unwrap(c: type)).getNumInputs();
528}
529
530intptr_t mlirFunctionTypeGetNumResults(MlirType type) {
531 return llvm::cast<FunctionType>(unwrap(c: type)).getNumResults();
532}
533
534MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) {
535 assert(pos >= 0 && "pos in array must be positive");
536 return wrap(llvm::cast<FunctionType>(unwrap(c: type))
537 .getInput(static_cast<unsigned>(pos)));
538}
539
540MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) {
541 assert(pos >= 0 && "pos in array must be positive");
542 return wrap(llvm::cast<FunctionType>(unwrap(c: type))
543 .getResult(static_cast<unsigned>(pos)));
544}
545
546//===----------------------------------------------------------------------===//
547// Opaque type.
548//===----------------------------------------------------------------------===//
549
550MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); }
551
552bool mlirTypeIsAOpaque(MlirType type) {
553 return llvm::isa<OpaqueType>(unwrap(type));
554}
555
556MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace,
557 MlirStringRef typeData) {
558 return wrap(
559 OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)),
560 unwrap(typeData)));
561}
562
563MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) {
564 return wrap(
565 llvm::cast<OpaqueType>(unwrap(type)).getDialectNamespace().strref());
566}
567
568MlirStringRef mlirOpaqueTypeGetData(MlirType type) {
569 return wrap(llvm::cast<OpaqueType>(unwrap(type)).getTypeData());
570}
571

// Source code of mlir/lib/CAPI/IR/BuiltinTypes.cpp