1//===- RunnerUtils.h - Utils for debugging MLIR execution -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file declares basic classes and functions to debug structured MLIR
10// types at runtime. Entities in this file may not be compatible with targets
11// without a C++ runtime. These may be progressively migrated to CRunnerUtils.h
12// over time.
13//
14//===----------------------------------------------------------------------===//
15
16#ifndef MLIR_EXECUTIONENGINE_RUNNERUTILS_H
17#define MLIR_EXECUTIONENGINE_RUNNERUTILS_H
18
// Symbol visibility for the entry points declared in this header. The guard
// covers every platform, so a build may pre-define MLIR_RUNNERUTILS_EXPORT to
// override the default without triggering a macro-redefinition diagnostic.
#ifndef MLIR_RUNNERUTILS_EXPORT
#ifdef _WIN32
#ifdef mlir_runner_utils_EXPORTS
// We are building this library
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllexport)
#else
// We are using this library
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllimport)
#endif // mlir_runner_utils_EXPORTS
#else
// Non-windows: use visibility attributes.
#define MLIR_RUNNERUTILS_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // MLIR_RUNNERUTILS_EXPORT
33
34#include <assert.h>
35#include <cmath>
36#include <complex>
37#include <iomanip>
38#include <iostream>
39
40#include "mlir/ExecutionEngine/CRunnerUtils.h"
41#include "mlir/ExecutionEngine/Float16bits.h"
42
43template <typename T, typename StreamType>
44void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &v) {
45 // Make the printed pointer format platform independent by casting it to an
46 // integer and manually formatting it to a hex with prefix as tests expect.
47 os << "base@ = " << std::hex << std::showbase
48 << reinterpret_cast<std::intptr_t>(v.data) << std::dec << std::noshowbase
49 << " rank = " << v.rank << " offset = " << v.offset;
50 auto print = [&](const int64_t *ptr) {
51 if (v.rank == 0)
52 return;
53 os << ptr[0];
54 for (int64_t i = 1; i < v.rank; ++i)
55 os << ", " << ptr[i];
56 };
57 os << " sizes = [";
58 print(v.sizes);
59 os << "] strides = [";
60 print(v.strides);
61 os << "]";
62}
63
64template <typename StreamType, typename T, int N>
65void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &v) {
66 static_assert(N >= 0, "Expected N > 0");
67 os << "MemRef ";
68 printMemRefMetaData(os, DynamicMemRefType<T>(v));
69}
70
71template <typename StreamType, typename T>
72void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &v) {
73 os << "Unranked MemRef ";
74 printMemRefMetaData(os, DynamicMemRefType<T>(v));
75}
76
77////////////////////////////////////////////////////////////////////////////////
78// Templated instantiation follows.
79////////////////////////////////////////////////////////////////////////////////
80namespace impl {
81using index_type = uint64_t;
82using complex64 = std::complex<double>;
83using complex32 = std::complex<float>;
84
85template <typename T, int M, int... Dims>
86std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
87
88template <int... Dims>
89struct StaticSizeMult {
90 static constexpr int value = 1;
91};
92
93template <int N, int... Dims>
94struct StaticSizeMult<N, Dims...> {
95 static constexpr int value = N * StaticSizeMult<Dims...>::value;
96};
97
98static inline void printSpace(std::ostream &os, int count) {
99 for (int i = 0; i < count; ++i) {
100 os << ' ';
101 }
102}
103
104template <typename T, int M, int... Dims>
105struct VectorDataPrinter {
106 static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
107};
108
109template <typename T, int M, int... Dims>
110void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
111 const Vector<T, M, Dims...> &val) {
112 static_assert(M > 0, "0 dimensioned tensor");
113 static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
114 "Incorrect vector size!");
115 // First
116 os << "(" << val[0];
117 if (M > 1)
118 os << ", ";
119 if (sizeof...(Dims) > 1)
120 os << "\n";
121 // Kernel
122 for (unsigned i = 1; i + 1 < M; ++i) {
123 printSpace(os, count: 2 * sizeof...(Dims));
124 os << val[i] << ", ";
125 if (sizeof...(Dims) > 1)
126 os << "\n";
127 }
128 // Last
129 if (M > 1) {
130 printSpace(os, count: sizeof...(Dims));
131 os << val[M - 1];
132 }
133 os << ")";
134}
135
136template <typename T, int M, int... Dims>
137std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
138 VectorDataPrinter<T, M, Dims...>::print(os, v);
139 return os;
140}
141
142template <typename T>
143struct MemRefDataPrinter {
144 static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
145 int64_t offset, const int64_t *sizes,
146 const int64_t *strides);
147 static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
148 int64_t offset, const int64_t *sizes,
149 const int64_t *strides);
150 static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
151 int64_t offset, const int64_t *sizes,
152 const int64_t *strides);
153};
154
155template <typename T>
156void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
157 int64_t rank, int64_t offset,
158 const int64_t *sizes,
159 const int64_t *strides) {
160 os << "[";
161 print(os, base, dim: dim - 1, rank, offset, sizes: sizes + 1, strides: strides + 1);
162 // If single element, close square bracket and return early.
163 if (sizes[0] <= 1) {
164 os << "]";
165 return;
166 }
167 os << ", ";
168 if (dim > 1)
169 os << "\n";
170}
171
172template <typename T>
173void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
174 int64_t rank, int64_t offset,
175 const int64_t *sizes, const int64_t *strides) {
176 if (dim == 0) {
177 os << base[offset];
178 return;
179 }
180 printFirst(os, base, dim, rank, offset, sizes, strides);
181 for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
182 printSpace(os, count: rank - dim + 1);
183 print(os, base, dim: dim - 1, rank, offset: offset + i * strides[0], sizes: sizes + 1,
184 strides: strides + 1);
185 os << ", ";
186 if (dim > 1)
187 os << "\n";
188 }
189 if (sizes[0] <= 1)
190 return;
191 printLast(os, base, dim, rank, offset, sizes, strides);
192}
193
194template <typename T>
195void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
196 int64_t rank, int64_t offset,
197 const int64_t *sizes,
198 const int64_t *strides) {
199 printSpace(os, count: rank - dim + 1);
200 print(os, base, dim: dim - 1, rank, offset: offset + (sizes[0] - 1) * (*strides),
201 sizes: sizes + 1, strides: strides + 1);
202 os << "]";
203}
204
205template <typename T, int N>
206void printMemRefShape(StridedMemRefType<T, N> &m) {
207 std::cout << "Memref ";
208 printMemRefMetaData(std::cout, DynamicMemRefType<T>(m));
209}
210
211template <typename T>
212void printMemRefShape(UnrankedMemRefType<T> &m) {
213 std::cout << "Unranked Memref ";
214 printMemRefMetaData(std::cout, DynamicMemRefType<T>(m));
215}
216
217template <typename T>
218void printMemRef(const DynamicMemRefType<T> &m) {
219 printMemRefMetaData(std::cout, m);
220 std::cout << " data = \n";
221 if (m.rank == 0)
222 std::cout << "[";
223 MemRefDataPrinter<T>::print(std::cout, m.data, m.rank, m.rank, m.offset,
224 m.sizes, m.strides);
225 if (m.rank == 0)
226 std::cout << "]";
227 std::cout << '\n' << std::flush;
228}
229
230template <typename T, int N>
231void printMemRef(StridedMemRefType<T, N> &m) {
232 std::cout << "Memref ";
233 printMemRef(DynamicMemRefType<T>(m));
234}
235
236template <typename T>
237void printMemRef(UnrankedMemRefType<T> &m) {
238 std::cout << "Unranked Memref ";
239 printMemRef(DynamicMemRefType<T>(m));
240}
241
242/// Verify the result of two computations are equivalent up to a small
243/// numerical error and return the number of errors.
244template <typename T>
245struct MemRefDataVerifier {
246 /// Maximum number of errors printed by the verifier.
247 static constexpr int printLimit = 10;
248
249 /// Verify the relative difference of the values is smaller than epsilon.
250 static bool verifyRelErrorSmallerThan(T actual, T expected, T epsilon);
251
252 /// Verify the values are equivalent (integers) or are close (floating-point).
253 static bool verifyElem(T actual, T expected);
254
255 /// Verify the data element-by-element and return the number of errors.
256 static int64_t verify(std::ostream &os, T *actualBasePtr, T *expectedBasePtr,
257 int64_t dim, int64_t offset, const int64_t *sizes,
258 const int64_t *strides, int64_t &printCounter);
259};
260
261template <typename T>
262bool MemRefDataVerifier<T>::verifyRelErrorSmallerThan(T actual, T expected,
263 T epsilon) {
264 // Return an error if one of the values is infinite or NaN.
265 if (!std::isfinite(actual) || !std::isfinite(expected))
266 return false;
267 // Return true if the relative error is smaller than epsilon.
268 T delta = std::abs(actual - expected);
269 return (delta <= epsilon * std::abs(expected));
270}
271
272template <typename T>
273bool MemRefDataVerifier<T>::verifyElem(T actual, T expected) {
274 return actual == expected;
275}
276
277template <>
278inline bool MemRefDataVerifier<double>::verifyElem(double actual,
279 double expected) {
280 return verifyRelErrorSmallerThan(actual, expected, epsilon: 1e-12);
281}
282
283template <>
284inline bool MemRefDataVerifier<float>::verifyElem(float actual,
285 float expected) {
286 return verifyRelErrorSmallerThan(actual, expected, epsilon: 1e-6f);
287}
288
289template <typename T>
290int64_t MemRefDataVerifier<T>::verify(std::ostream &os, T *actualBasePtr,
291 T *expectedBasePtr, int64_t dim,
292 int64_t offset, const int64_t *sizes,
293 const int64_t *strides,
294 int64_t &printCounter) {
295 int64_t errors = 0;
296 // Verify the elements at the current offset.
297 if (dim == 0) {
298 if (!verifyElem(actual: actualBasePtr[offset], expected: expectedBasePtr[offset])) {
299 if (printCounter < printLimit) {
300 os << actualBasePtr[offset] << " != " << expectedBasePtr[offset]
301 << " offset = " << offset << "\n";
302 printCounter++;
303 }
304 errors++;
305 }
306 } else {
307 // Iterate the current dimension and verify recursively.
308 for (int64_t i = 0; i < sizes[0]; ++i) {
309 errors +=
310 verify(os, actualBasePtr, expectedBasePtr, dim: dim - 1,
311 offset: offset + i * strides[0], sizes: sizes + 1, strides: strides + 1, printCounter);
312 }
313 }
314 return errors;
315}
316
317/// Verify the equivalence of two dynamic memrefs and return the number of
318/// errors or -1 if the shape of the memrefs do not match.
319template <typename T>
320int64_t verifyMemRef(const DynamicMemRefType<T> &actual,
321 const DynamicMemRefType<T> &expected) {
322 // Check if the memref shapes match.
323 for (int64_t i = 0; i < actual.rank; ++i) {
324 if (expected.rank != actual.rank || actual.offset != expected.offset ||
325 actual.sizes[i] != expected.sizes[i] ||
326 actual.strides[i] != expected.strides[i]) {
327 printMemRefMetaData(std::cerr, actual);
328 printMemRefMetaData(std::cerr, expected);
329 return -1;
330 }
331 }
332 // Return the number of errors.
333 int64_t printCounter = 0;
334 return MemRefDataVerifier<T>::verify(std::cerr, actual.data, expected.data,
335 actual.rank, actual.offset, actual.sizes,
336 actual.strides, printCounter);
337}
338
339/// Verify the equivalence of two unranked memrefs and return the number of
340/// errors or -1 if the shape of the memrefs do not match.
341template <typename T>
342int64_t verifyMemRef(UnrankedMemRefType<T> &actual,
343 UnrankedMemRefType<T> &expected) {
344 return verifyMemRef(DynamicMemRefType<T>(actual),
345 DynamicMemRefType<T>(expected));
346}
347
348} // namespace impl
349
350////////////////////////////////////////////////////////////////////////////////
351// Currently exposed C API.
352////////////////////////////////////////////////////////////////////////////////
353extern "C" MLIR_RUNNERUTILS_EXPORT void
354_mlir_ciface_printMemrefShapeI8(UnrankedMemRefType<int8_t> *m);
355extern "C" MLIR_RUNNERUTILS_EXPORT void
356_mlir_ciface_printMemrefShapeI32(UnrankedMemRefType<int32_t> *m);
357extern "C" MLIR_RUNNERUTILS_EXPORT void
358_mlir_ciface_printMemrefShapeI64(UnrankedMemRefType<int64_t> *m);
359extern "C" MLIR_RUNNERUTILS_EXPORT void
360_mlir_ciface_printMemrefShapeF32(UnrankedMemRefType<float> *m);
361extern "C" MLIR_RUNNERUTILS_EXPORT void
362_mlir_ciface_printMemrefShapeF64(UnrankedMemRefType<double> *m);
363extern "C" MLIR_RUNNERUTILS_EXPORT void
364_mlir_ciface_printMemrefShapeInd(UnrankedMemRefType<impl::index_type> *m);
365extern "C" MLIR_RUNNERUTILS_EXPORT void
366_mlir_ciface_printMemrefShapeC32(UnrankedMemRefType<impl::complex32> *m);
367extern "C" MLIR_RUNNERUTILS_EXPORT void
368_mlir_ciface_printMemrefShapeC64(UnrankedMemRefType<impl::complex64> *m);
369
370extern "C" MLIR_RUNNERUTILS_EXPORT void
371_mlir_ciface_printMemrefI8(UnrankedMemRefType<int8_t> *m);
372extern "C" MLIR_RUNNERUTILS_EXPORT void
373_mlir_ciface_printMemrefI16(UnrankedMemRefType<int16_t> *m);
374extern "C" MLIR_RUNNERUTILS_EXPORT void
375_mlir_ciface_printMemrefI32(UnrankedMemRefType<int32_t> *m);
376extern "C" MLIR_RUNNERUTILS_EXPORT void
377_mlir_ciface_printMemrefI64(UnrankedMemRefType<int64_t> *m);
378extern "C" MLIR_RUNNERUTILS_EXPORT void
379_mlir_ciface_printMemrefF16(UnrankedMemRefType<f16> *m);
380extern "C" MLIR_RUNNERUTILS_EXPORT void
381_mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *m);
382extern "C" MLIR_RUNNERUTILS_EXPORT void
383_mlir_ciface_printMemrefF32(UnrankedMemRefType<float> *m);
384extern "C" MLIR_RUNNERUTILS_EXPORT void
385_mlir_ciface_printMemrefF64(UnrankedMemRefType<double> *m);
386extern "C" MLIR_RUNNERUTILS_EXPORT void
387_mlir_ciface_printMemrefInd(UnrankedMemRefType<impl::index_type> *m);
388extern "C" MLIR_RUNNERUTILS_EXPORT void
389_mlir_ciface_printMemrefC32(UnrankedMemRefType<impl::complex32> *m);
390extern "C" MLIR_RUNNERUTILS_EXPORT void
391_mlir_ciface_printMemrefC64(UnrankedMemRefType<impl::complex64> *m);
392
393extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_nanoTime();
394
395extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI32(int64_t rank, void *ptr);
396extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefI64(int64_t rank, void *ptr);
397extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF32(int64_t rank, void *ptr);
398extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefF64(int64_t rank, void *ptr);
399extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefInd(int64_t rank, void *ptr);
400extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC32(int64_t rank, void *ptr);
401extern "C" MLIR_RUNNERUTILS_EXPORT void printMemrefC64(int64_t rank, void *ptr);
402
403extern "C" MLIR_RUNNERUTILS_EXPORT void
404_mlir_ciface_printMemref0dF32(StridedMemRefType<float, 0> *m);
405extern "C" MLIR_RUNNERUTILS_EXPORT void
406_mlir_ciface_printMemref1dF32(StridedMemRefType<float, 1> *m);
407extern "C" MLIR_RUNNERUTILS_EXPORT void
408_mlir_ciface_printMemref2dF32(StridedMemRefType<float, 2> *m);
409extern "C" MLIR_RUNNERUTILS_EXPORT void
410_mlir_ciface_printMemref3dF32(StridedMemRefType<float, 3> *m);
411extern "C" MLIR_RUNNERUTILS_EXPORT void
412_mlir_ciface_printMemref4dF32(StridedMemRefType<float, 4> *m);
413
414extern "C" MLIR_RUNNERUTILS_EXPORT void
415_mlir_ciface_printMemref1dI8(StridedMemRefType<int8_t, 1> *m);
416extern "C" MLIR_RUNNERUTILS_EXPORT void
417_mlir_ciface_printMemref1dI32(StridedMemRefType<int32_t, 1> *m);
418extern "C" MLIR_RUNNERUTILS_EXPORT void
419_mlir_ciface_printMemref1dI64(StridedMemRefType<int64_t, 1> *m);
420extern "C" MLIR_RUNNERUTILS_EXPORT void
421_mlir_ciface_printMemref1dF64(StridedMemRefType<double, 1> *m);
422extern "C" MLIR_RUNNERUTILS_EXPORT void
423_mlir_ciface_printMemref1dInd(StridedMemRefType<impl::index_type, 1> *m);
424extern "C" MLIR_RUNNERUTILS_EXPORT void
425_mlir_ciface_printMemref1dC32(StridedMemRefType<impl::complex32, 1> *m);
426extern "C" MLIR_RUNNERUTILS_EXPORT void
427_mlir_ciface_printMemref1dC64(StridedMemRefType<impl::complex64, 1> *m);
428
429extern "C" MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefVector4x4xf32(
430 StridedMemRefType<Vector2D<4, 4, float>, 2> *m);
431
432extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI8(
433 UnrankedMemRefType<int8_t> *actual, UnrankedMemRefType<int8_t> *expected);
434extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI16(
435 UnrankedMemRefType<int16_t> *actual, UnrankedMemRefType<int16_t> *expected);
436extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI32(
437 UnrankedMemRefType<int32_t> *actual, UnrankedMemRefType<int32_t> *expected);
438extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI64(
439 UnrankedMemRefType<int64_t> *actual, UnrankedMemRefType<int64_t> *expected);
440extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefBF16(
441 UnrankedMemRefType<bf16> *actual, UnrankedMemRefType<bf16> *expected);
442extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF16(
443 UnrankedMemRefType<f16> *actual, UnrankedMemRefType<f16> *expected);
444extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF32(
445 UnrankedMemRefType<float> *actual, UnrankedMemRefType<float> *expected);
446extern "C" MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF64(
447 UnrankedMemRefType<double> *actual, UnrankedMemRefType<double> *expected);
448extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
449_mlir_ciface_verifyMemRefInd(UnrankedMemRefType<impl::index_type> *actual,
450 UnrankedMemRefType<impl::index_type> *expected);
451extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
452_mlir_ciface_verifyMemRefC32(UnrankedMemRefType<impl::complex32> *actual,
453 UnrankedMemRefType<impl::complex32> *expected);
454extern "C" MLIR_RUNNERUTILS_EXPORT int64_t
455_mlir_ciface_verifyMemRefC64(UnrankedMemRefType<impl::complex64> *actual,
456 UnrankedMemRefType<impl::complex64> *expected);
457
458extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefI32(int64_t rank,
459 void *actualPtr,
460 void *expectedPtr);
461extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF32(int64_t rank,
462 void *actualPtr,
463 void *expectedPtr);
464extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF64(int64_t rank,
465 void *actualPtr,
466 void *expectedPtr);
467extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefInd(int64_t rank,
468 void *actualPtr,
469 void *expectedPtr);
470extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC32(int64_t rank,
471 void *actualPtr,
472 void *expectedPtr);
473extern "C" MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC64(int64_t rank,
474 void *actualPtr,
475 void *expectedPtr);
476
477#endif // MLIR_EXECUTIONENGINE_RUNNERUTILS_H
478

// source code of mlir/include/mlir/ExecutionEngine/RunnerUtils.h