1 | //===- sparse_tensor.c - Test of sparse_tensor APIs -----------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM |
4 | // Exceptions. |
5 | // See https://llvm.org/LICENSE.txt for license information. |
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
7 | // |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | // RUN: mlir-capi-sparse-tensor-test 2>&1 | FileCheck %s |
11 | |
12 | #include "mlir-c/Dialect/SparseTensor.h" |
13 | #include "mlir-c/IR.h" |
14 | #include "mlir-c/RegisterEverything.h" |
15 | |
16 | #include <assert.h> |
17 | #include <inttypes.h> |
18 | #include <math.h> |
19 | #include <stdio.h> |
20 | #include <stdlib.h> |
21 | #include <string.h> |
22 | |
23 | // CHECK-LABEL: testRoundtripEncoding() |
24 | static int testRoundtripEncoding(MlirContext ctx) { |
25 | fprintf(stderr, format: "testRoundtripEncoding()\n" ); |
26 | // clang-format off |
27 | const char *originalAsm = |
28 | "#sparse_tensor.encoding<{ " |
29 | "map = [s0](d0, d1) -> (s0 : dense, d0 : compressed, d1 : compressed), " |
30 | "posWidth = 32, crdWidth = 64 }>" ; |
31 | // clang-format on |
32 | MlirAttribute originalAttr = |
33 | mlirAttributeParseGet(context: ctx, attr: mlirStringRefCreateFromCString(str: originalAsm)); |
34 | // CHECK: isa: 1 |
35 | fprintf(stderr, format: "isa: %d\n" , |
36 | mlirAttributeIsASparseTensorEncodingAttr(attr: originalAttr)); |
37 | MlirAffineMap dimToLvl = |
38 | mlirSparseTensorEncodingAttrGetDimToLvl(attr: originalAttr); |
39 | // CHECK: (d0, d1)[s0] -> (s0, d0, d1) |
40 | mlirAffineMapDump(affineMap: dimToLvl); |
41 | // CHECK: level_type: 65536 |
42 | // CHECK: level_type: 262144 |
43 | // CHECK: level_type: 262144 |
44 | MlirAffineMap lvlToDim = |
45 | mlirSparseTensorEncodingAttrGetLvlToDim(attr: originalAttr); |
46 | int lvlRank = mlirSparseTensorEncodingGetLvlRank(attr: originalAttr); |
47 | MlirSparseTensorLevelType *lvlTypes = |
48 | malloc(size: sizeof(MlirSparseTensorLevelType) * lvlRank); |
49 | for (int l = 0; l < lvlRank; ++l) { |
50 | lvlTypes[l] = mlirSparseTensorEncodingAttrGetLvlType(attr: originalAttr, lvl: l); |
51 | fprintf(stderr, format: "level_type: %" PRIu64 "\n" , lvlTypes[l]); |
52 | } |
53 | // CHECK: posWidth: 32 |
54 | int posWidth = mlirSparseTensorEncodingAttrGetPosWidth(attr: originalAttr); |
55 | fprintf(stderr, format: "posWidth: %d\n" , posWidth); |
56 | // CHECK: crdWidth: 64 |
57 | int crdWidth = mlirSparseTensorEncodingAttrGetCrdWidth(attr: originalAttr); |
58 | fprintf(stderr, format: "crdWidth: %d\n" , crdWidth); |
59 | MlirAttribute newAttr = mlirSparseTensorEncodingAttrGet( |
60 | ctx, lvlRank, lvlTypes, dimToLvl, lvlTodim: lvlToDim, posWidth, crdWidth); |
61 | mlirAttributeDump(attr: newAttr); // For debugging filecheck output. |
62 | // CHECK: equal: 1 |
63 | fprintf(stderr, format: "equal: %d\n" , mlirAttributeEqual(a1: originalAttr, a2: newAttr)); |
64 | free(ptr: lvlTypes); |
65 | return 0; |
66 | } |
67 | |
68 | int main(void) { |
69 | MlirContext ctx = mlirContextCreate(); |
70 | mlirDialectHandleRegisterDialect(mlirGetDialectHandle__sparse_tensor__(), |
71 | ctx); |
72 | if (testRoundtripEncoding(ctx)) |
73 | return 1; |
74 | |
75 | mlirContextDestroy(context: ctx); |
76 | return 0; |
77 | } |
78 | |