1 | // Split the MLIR string: this will produce %t/input.mlir |
2 | // RUN: split-file %s %t |
3 | |
4 | // Compile the MLIR file to LLVM: |
5 | // RUN: mlir-opt %t/input.mlir \ |
6 | // RUN: -lower-affine -convert-scf-to-cf -finalize-memref-to-llvm \ |
7 | // RUN: -convert-func-to-llvm -reconcile-unrealized-casts \ |
8 | // RUN: | mlir-translate --mlir-to-llvmir -o %t.ll |
9 | |
10 | // Generate an object file for the MLIR code |
11 | // RUN: llc %t.ll -o %t.o -filetype=obj |
12 | |
13 | // Compile the current C file and link it to the MLIR code: |
14 | // RUN: %host_cc %s %t.o -o %t.exe |
15 | |
// Run the executable and check its output with FileCheck:
17 | // RUN: %t.exe | FileCheck %s |
18 | |
19 | /* MLIR_BEGIN |
20 | //--- input.mlir |
// Performs arg0[i, j] = arg0[i, j] + arg1[i, j] in place for every (i, j), then returns the constant 42.
22 | func.func private @add_memref(%arg0: memref<?x?xf64>, %arg1: memref<?x?xf64>) -> i64 |
23 | attributes {llvm.emit_c_interface} { |
24 | %c0 = arith.constant 0 : index |
25 | %c1 = arith.constant 1 : index |
26 | %dimI = memref.dim %arg0, %c0 : memref<?x?xf64> |
27 | %dimJ = memref.dim %arg0, %c1 : memref<?x?xf64> |
28 | affine.for %i = 0 to %dimI { |
29 | affine.for %j = 0 to %dimJ { |
30 | %load0 = memref.load %arg0[%i, %j] : memref<?x?xf64> |
31 | %load1 = memref.load %arg1[%i, %j] : memref<?x?xf64> |
32 | %add = arith.addf %load0, %load1 : f64 |
33 | affine.store %add, %arg0[%i, %j] : memref<?x?xf64> |
34 | } |
35 | } |
36 | %c42 = arith.constant 42 : i64 |
37 | return %c42 : i64 |
38 | } |
39 | |
40 | //--- end_input.mlir |
41 | |
42 | MLIR_END */ |
43 | |
44 | #include <stdint.h> |
45 | #include <stdio.h> |
46 | |
// Declare the expanded-ABI entry point that MLIR emits for @add_memref, see
// https://mlir.llvm.org/docs/TargetLLVMIR/#calling-conventions for details.
//
// The function takes two 2D memrefs. Each memref argument is expanded into its
// descriptor fields, so the signature in the MLIR LLVM dialect will be:
//   llvm.func @add_memref(
//       // First Memref (%arg0)
//       %allocated_ptr0: !llvm.ptr<f64>, %aligned_ptr0: !llvm.ptr<f64>,
//       %offset0: i64, %size0_d0: i64, %size0_d1: i64, %stride0_d0: i64,
//       %stride0_d1: i64,
//       // Second Memref (%arg1)
//       %allocated_ptr1: !llvm.ptr<f64>, %aligned_ptr1: !llvm.ptr<f64>,
//       %offset1: i64, %size1_d0: i64, %size1_d1: i64, %stride1_d0: i64,
//       %stride1_d1: i64)
//       -> i64
//
// The i64 descriptor fields are declared as intptr_t below, which matches i64
// only on targets with 64-bit pointers; the i64 result maps to long long.
long long add_memref(double *allocated_ptr0, double *aligned_ptr0,
                     intptr_t offset0, intptr_t size0_d0, intptr_t size0_d1,
                     intptr_t stride0_d0, intptr_t stride0_d1,
                     // Second Memref (%arg1)
                     double *allocated_ptr1, double *aligned_ptr1,
                     intptr_t offset1, intptr_t size1_d0, intptr_t size1_d1,
                     intptr_t stride1_d0, intptr_t stride1_d1);
68 | |
// Because of the llvm.emit_c_interface attribute, MLIR additionally emits a
// wrapper that receives each memref as a pointer to its descriptor struct:
//   llvm.func @_mlir_ciface_add_memref(
//       %arg0: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                                array<2 x i64>, array<2 x i64>)>>,
//       %arg1: !llvm.ptr<struct<(ptr<f64>, ptr<f64>, i64,
//                                array<2 x i64>, array<2 x i64>)>>)
//       -> i64
// The C struct below mirrors that descriptor field by field.
typedef struct memref_2d_descriptor {
  double *allocated;  // allocated pointer
  double *aligned;    // aligned pointer
  intptr_t offset;    // offset from the aligned pointer, in elements
  intptr_t size[2];   // size of dimension 0 and dimension 1
  intptr_t stride[2]; // stride of dimension 0 and dimension 1, in elements
} memref_2d_descriptor;

long long _mlir_ciface_add_memref(memref_2d_descriptor *arg0,
                                  memref_2d_descriptor *arg1);
85 | |
// Shape of both test matrices: N rows by M columns.
enum { N = 4, M = 8 };

// Row-major test inputs shared by main() and dump(). The MLIR function stores
// its sums into arg0 and only loads from arg1.
double arg0[N][M];
double arg1[N][M];
90 | |
91 | void dump() { |
92 | for (int i = 0; i < N; i++) { |
93 | printf(format: "[" ); |
94 | for (int j = 0; j < M; j++) |
95 | printf(format: "%d,\t" , (int)arg0[i][j]); |
96 | printf(format: "] [" ); |
97 | for (int j = 0; j < M; j++) |
98 | printf(format: "%d,\t" , (int)arg1[i][j]); |
99 | printf(format: "]\n" ); |
100 | } |
101 | } |
102 | |
103 | int main() { |
104 | int count = 0; |
105 | for (int i = 0; i < N; i++) { |
106 | for (int j = 0; j < M; j++) { |
107 | arg0[i][j] = count++; |
108 | arg1[i][j] = count++; |
109 | } |
110 | } |
111 | printf(format: "Before:\n" ); |
112 | dump(); |
113 | // clang-format off |
114 | // CHECK-LABEL: Before: |
115 | // CHECK: [0, 2, 4, 6, 8, 10, 12, 14, ] [1, 3, 5, 7, 9, 11, 13, 15, ] |
116 | // CHECK: [16, 18, 20, 22, 24, 26, 28, 30, ] [17, 19, 21, 23, 25, 27, 29, 31, ] |
117 | // CHECK: [32, 34, 36, 38, 40, 42, 44, 46, ] [33, 35, 37, 39, 41, 43, 45, 47, ] |
118 | // CHECK: [48, 50, 52, 54, 56, 58, 60, 62, ] [49, 51, 53, 55, 57, 59, 61, 63, ] |
119 | // clang-format on |
120 | |
121 | // Call into MLIR. |
122 | long long result = add_memref(allocated_ptr0: (double *)arg0, aligned_ptr0: (double *)arg0, offset0: 0, N, M, M, stride0_d1: 0, |
123 | // |
124 | allocated_ptr1: (double *)arg1, aligned_ptr1: (double *)arg1, offset1: 0, N, M, M, stride1_d1: 0); |
125 | |
126 | // CHECK-LABEL: Result: |
127 | // CHECK: 42 |
128 | printf(format: "Result: %d\n" , (int)result); |
129 | |
130 | printf(format: "After:\n" ); |
131 | dump(); |
132 | |
133 | // clang-format off |
134 | // CHECK-LABEL: After: |
135 | // CHECK: [1, 5, 9, 13, 17, 21, 25, 29, ] [1, 3, 5, 7, 9, 11, 13, 15, ] |
136 | // CHECK: [33, 37, 41, 45, 49, 53, 57, 61, ] [17, 19, 21, 23, 25, 27, 29, 31, ] |
137 | // CHECK: [65, 69, 73, 77, 81, 85, 89, 93, ] [33, 35, 37, 39, 41, 43, 45, 47, ] |
138 | // CHECK: [97, 101, 105, 109, 113, 117, 121, 125, ] [49, 51, 53, 55, 57, 59, 61, 63, ] |
139 | // clang-format on |
140 | |
141 | // Reset the input and re-apply the same function use the C API wrapper. |
142 | count = 0; |
143 | for (int i = 0; i < N; i++) { |
144 | for (int j = 0; j < M; j++) { |
145 | arg0[i][j] = count++; |
146 | arg1[i][j] = count++; |
147 | } |
148 | } |
149 | |
150 | // Call into MLIR. |
151 | memref_2d_descriptor arg0_descriptor = { |
152 | .allocated: (double *)arg0, .aligned: (double *)arg0, .offset: 0, N, M, M, 0}; |
153 | memref_2d_descriptor arg1_descriptor = { |
154 | .allocated: (double *)arg1, .aligned: (double *)arg1, .offset: 0, N, M, M, 0}; |
155 | result = _mlir_ciface_add_memref(arg0: &arg0_descriptor, arg1: &arg1_descriptor); |
156 | |
157 | // CHECK-LABEL: Result2: |
158 | // CHECK: 42 |
159 | printf(format: "Result2: %d\n" , (int)result); |
160 | |
161 | printf(format: "After2:\n" ); |
162 | dump(); |
163 | |
164 | // clang-format off |
165 | // CHECK-LABEL: After2: |
166 | // CHECK: [1, 5, 9, 13, 17, 21, 25, 29, ] [1, 3, 5, 7, 9, 11, 13, 15, ] |
167 | // CHECK: [33, 37, 41, 45, 49, 53, 57, 61, ] [17, 19, 21, 23, 25, 27, 29, 31, ] |
168 | // CHECK: [65, 69, 73, 77, 81, 85, 89, 93, ] [33, 35, 37, 39, 41, 43, 45, 47, ] |
169 | // CHECK: [97, 101, 105, 109, 113, 117, 121, 125, ] [49, 51, 53, 55, 57, 59, 61, 63, ] |
170 | // clang-format on |
171 | |
172 | return 0; |
173 | } |
174 | |