diff mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h @ 150:1d019706d866

LLVM10
author anatofuz
date Thu, 13 Feb 2020 15:10:13 +0900
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mlir/test/mlir-cpu-runner/include/mlir_runner_utils.h	Thu Feb 13 15:10:13 2020 +0900
@@ -0,0 +1,294 @@
+//===- mlir_runner_utils.h - Utils for debugging MLIR CPU execution -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CPU_RUNNER_MLIRUTILS_H_
+#define MLIR_CPU_RUNNER_MLIRUTILS_H_
+
+#include <assert.h>
+#include <cstdint>
+#include <iostream>
+
+#ifdef _WIN32
+#ifndef MLIR_RUNNER_UTILS_EXPORT
+#ifdef mlir_runner_utils_EXPORTS
+/* We are building this library */
+#define MLIR_RUNNER_UTILS_EXPORT __declspec(dllexport)
+#else
+/* We are using this library */
+#define MLIR_RUNNER_UTILS_EXPORT __declspec(dllimport)
+#endif // mlir_runner_utils_EXPORTS
+#endif // MLIR_RUNNER_UTILS_EXPORT
+#else
+#define MLIR_RUNNER_UTILS_EXPORT
+#endif // _WIN32
+
+template <typename T, int N> struct StridedMemRefType;
+template <typename StreamType, typename T, int N>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V);
+
+template <int N> void dropFront(int64_t arr[N], int64_t *res) {
+  for (unsigned i = 1; i < N; ++i)
+    *(res + i - 1) = arr[i];
+}
+
+/// StridedMemRef descriptor type with static rank.
+template <typename T, int N> struct StridedMemRefType {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[N];
+  int64_t strides[N];
+  // This operator[] is extremely slow and only for sugaring purposes.
+  StridedMemRefType<T, N - 1> operator[](int64_t idx) {
+    StridedMemRefType<T, N - 1> res;
+    res.basePtr = basePtr;
+    res.data = data;
+    res.offset = offset + idx * strides[0];
+    dropFront<N>(sizes, res.sizes);
+    dropFront<N>(strides, res.strides);
+    return res;
+  }
+};
+
+/// StridedMemRef descriptor type specialized for rank 1.
+template <typename T> struct StridedMemRefType<T, 1> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+  int64_t sizes[1];
+  int64_t strides[1];
+  T &operator[](int64_t idx) { return *(data + offset + idx * strides[0]); }
+};
+
+/// StridedMemRef descriptor type specialized for rank 0.
+template <typename T> struct StridedMemRefType<T, 0> {
+  T *basePtr;
+  T *data;
+  int64_t offset;
+};
+
+// Unranked MemRef
+template <typename T> struct UnrankedMemRefType {
+  int64_t rank;
+  void *descriptor;
+};
+
+template <typename StreamType, typename T, int N>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
+  static_assert(N > 0, "Expected N > 0");
+  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << N
+     << " offset = " << V.offset << " sizes = [" << V.sizes[0];
+  for (unsigned i = 1; i < N; ++i)
+    os << ", " << V.sizes[i];
+  os << "] strides = [" << V.strides[0];
+  for (unsigned i = 1; i < N; ++i)
+    os << ", " << V.strides[i];
+  os << "]";
+}
+
+template <typename StreamType, typename T>
+void printMemRefMetaData(StreamType &os, StridedMemRefType<T, 0> &V) {
+  os << "Memref base@ = " << reinterpret_cast<void *>(V.data) << " rank = 0"
+     << " offset = " << V.offset;
+}
+
+template <typename T, typename StreamType>
+void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
+  os << "Unranked Memref rank = " << V.rank << " "
+     << "descriptor@ = " << reinterpret_cast<void *>(V.descriptor) << "\n";
+}
+
+template <typename T, int Dim, int... Dims> struct Vector {
+  Vector<T, Dims...> vector[Dim];
+};
+template <typename T, int Dim> struct Vector<T, Dim> { T vector[Dim]; };
+
+template <int D1, typename T> using Vector1D = Vector<T, D1>;
+template <int D1, int D2, typename T> using Vector2D = Vector<T, D1, D2>;
+template <int D1, int D2, int D3, typename T>
+using Vector3D = Vector<T, D1, D2, D3>;
+template <int D1, int D2, int D3, int D4, typename T>
+using Vector4D = Vector<T, D1, D2, D3, D4>;
+
+////////////////////////////////////////////////////////////////////////////////
+// Templated instantiation follows.
+////////////////////////////////////////////////////////////////////////////////
+namespace impl {
+template <typename T, int M, int... Dims>
+std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
+
+template <int... Dims> struct StaticSizeMult {
+  static constexpr int value = 1;
+};
+
+template <int N, int... Dims> struct StaticSizeMult<N, Dims...> {
+  static constexpr int value = N * StaticSizeMult<Dims...>::value;
+};
+
+static inline void printSpace(std::ostream &os, int count) {
+  for (int i = 0; i < count; ++i) {
+    os << ' ';
+  }
+}
+
+template <typename T, int M, int... Dims> struct VectorDataPrinter {
+  static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
+};
+
+template <typename T, int M, int... Dims>
+void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
+                                             const Vector<T, M, Dims...> &val) {
+  static_assert(M > 0, "0 dimensioned tensor");
+  static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
+                "Incorrect vector size!");
+  // First
+  os << "(" << val.vector[0];
+  if (M > 1)
+    os << ", ";
+  if (sizeof...(Dims) > 1)
+    os << "\n";
+  // Kernel
+  for (unsigned i = 1; i + 1 < M; ++i) {
+    printSpace(os, 2 * sizeof...(Dims));
+    os << val.vector[i] << ", ";
+    if (sizeof...(Dims) > 1)
+      os << "\n";
+  }
+  // Last
+  if (M > 1) {
+    printSpace(os, sizeof...(Dims));
+    os << val.vector[M - 1];
+  }
+  os << ")";
+}
+
+template <typename T, int M, int... Dims>
+std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
+  VectorDataPrinter<T, M, Dims...>::print(os, v);
+  return os;
+}
+
+template <typename T, int N> struct MemRefDataPrinter {
+  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
+                    int64_t *sizes, int64_t *strides);
+  static void printFirst(std::ostream &os, T *base, int64_t rank,
+                         int64_t offset, int64_t *sizes, int64_t *strides);
+  static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset,
+                        int64_t *sizes, int64_t *strides);
+};
+
+template <typename T> struct MemRefDataPrinter<T, 0> {
+  static void print(std::ostream &os, T *base, int64_t rank, int64_t offset,
+                    int64_t *sizes = nullptr, int64_t *strides = nullptr);
+};
+
+template <typename T, int N>
+void MemRefDataPrinter<T, N>::printFirst(std::ostream &os, T *base,
+                                         int64_t rank, int64_t offset,
+                                         int64_t *sizes, int64_t *strides) {
+  os << "[";
+  MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset, sizes + 1,
+                                     strides + 1);
+  // If single element, close square bracket and return early.
+  if (sizes[0] <= 1) {
+    os << "]";
+    return;
+  }
+  os << ", ";
+  if (N > 1)
+    os << "\n";
+}
+
+template <typename T, int N>
+void MemRefDataPrinter<T, N>::print(std::ostream &os, T *base, int64_t rank,
+                                    int64_t offset, int64_t *sizes,
+                                    int64_t *strides) {
+  printFirst(os, base, rank, offset, sizes, strides);
+  for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
+    printSpace(os, rank - N + 1);
+    MemRefDataPrinter<T, N - 1>::print(os, base, rank, offset + i * strides[0],
+                                       sizes + 1, strides + 1);
+    os << ", ";
+    if (N > 1)
+      os << "\n";
+  }
+  if (sizes[0] <= 1)
+    return;
+  printLast(os, base, rank, offset, sizes, strides);
+}
+
+template <typename T, int N>
+void MemRefDataPrinter<T, N>::printLast(std::ostream &os, T *base, int64_t rank,
+                                        int64_t offset, int64_t *sizes,
+                                        int64_t *strides) {
+  printSpace(os, rank - N + 1);
+  MemRefDataPrinter<T, N - 1>::print(os, base, rank,
+                                     offset + (sizes[0] - 1) * (*strides),
+                                     sizes + 1, strides + 1);
+  os << "]";
+}
+
+template <typename T>
+void MemRefDataPrinter<T, 0>::print(std::ostream &os, T *base, int64_t rank,
+                                    int64_t offset, int64_t *sizes,
+                                    int64_t *strides) {
+  os << base[offset];
+}
+
+template <typename T, int N> void printMemRef(StridedMemRefType<T, N> &M) {
+  static_assert(N > 0, "Expected N > 0");
+  printMemRefMetaData(std::cout, M);
+  std::cout << " data = " << std::endl;
+  MemRefDataPrinter<T, N>::print(std::cout, M.data, N, M.offset, M.sizes,
+                                 M.strides);
+  std::cout << std::endl;
+}
+
+template <typename T> void printMemRef(StridedMemRefType<T, 0> &M) {
+  printMemRefMetaData(std::cout, M);
+  std::cout << " data = " << std::endl;
+  std::cout << "[";
+  MemRefDataPrinter<T, 0>::print(std::cout, M.data, 0, M.offset);
+  std::cout << "]" << std::endl;
+}
+} // namespace impl
+
+////////////////////////////////////////////////////////////////////////////////
+// Currently exposed C API.
+////////////////////////////////////////////////////////////////////////////////
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_memref_f32(int64_t rank,
+                                                          void *ptr);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
+
+extern "C" MLIR_RUNNER_UTILS_EXPORT void
+_mlir_ciface_print_memref_vector_4x4xf32(
+    StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
+
+// Small runtime support "lib" for vector.print lowering.
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f32(float f);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_f64(double d);
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_open();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_close();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_comma();
+extern "C" MLIR_RUNNER_UTILS_EXPORT void print_newline();
+
+#endif // MLIR_CPU_RUNNER_MLIRUTILS_H_