diff options
Diffstat (limited to 'sources/shiboken6/libshiboken/sbknumpyview.cpp')
-rw-r--r-- | sources/shiboken6/libshiboken/sbknumpyview.cpp | 265 |
1 files changed, 265 insertions, 0 deletions
diff --git a/sources/shiboken6/libshiboken/sbknumpyview.cpp b/sources/shiboken6/libshiboken/sbknumpyview.cpp new file mode 100644 index 000000000..bafbf8038 --- /dev/null +++ b/sources/shiboken6/libshiboken/sbknumpyview.cpp @@ -0,0 +1,265 @@ +// Copyright (C) 2021 The Qt Company Ltd. +// SPDX-License-Identifier: LicenseRef-Qt-Commercial OR LGPL-3.0-only OR GPL-2.0-only OR GPL-3.0-only + +// included by sbknumpy.cpp + +#include "helper.h" +#include <iostream> +#include <iomanip> +#include <optional> + +#ifdef HAVE_NUMPY + +namespace Shiboken { +namespace Numpy { + +static std::optional<View::Type> viewTypeFromNumPy(int npt) +{ + switch (npt) { + case NPY_SHORT: + return View::Int16; + case NPY_USHORT: + return View::Unsigned16; + case NPY_INT: + return View::Int; + case NPY_UINT: + return View::Unsigned; + case NPY_LONG: + if constexpr (sizeof(long) == sizeof(int)) + return View::Int; + if constexpr (sizeof(long) == sizeof(int64_t)) + return View::Int64; + break; + case NPY_ULONG: + if constexpr (sizeof(long) == sizeof(int)) + return View::Unsigned; + if constexpr (sizeof(long) == sizeof(int64_t)) + return View::Unsigned64; + break; + case NPY_LONGLONG: + if constexpr (sizeof(long long) == 8) + return View::Int64; + break; + case NPY_ULONGLONG: + if constexpr (sizeof(long long) == 8) + return View::Unsigned64; + break; + case NPY_FLOAT: + return View::Float; + case NPY_DOUBLE: + return View::Double; + default: + break; + } + return {}; +} + +View View::fromPyObject(PyObject *pyIn) +{ + if (pyIn == nullptr || PyArray_Check(pyIn) == 0) + return {}; + auto *ar = reinterpret_cast<PyArrayObject *>(pyIn); + if ((PyArray_FLAGS(ar) & NPY_ARRAY_C_CONTIGUOUS) == 0) + return {}; + const int ndim = PyArray_NDIM(ar); + if (ndim > 2) + return {}; + + const auto typeO = viewTypeFromNumPy(PyArray_TYPE(ar)); + if (!typeO.has_value()) + return {}; + + View result; + result.ndim = ndim; + result.type = typeO.value(); + result.data = PyArray_DATA(ar); + result.dimensions[0] = PyArray_DIMS(ar)[0]; + result.stride[0] = PyArray_STRIDES(ar)[0]; + if (ndim > 1) { + result.dimensions[1] = PyArray_DIMS(ar)[1]; + result.stride[1] = PyArray_STRIDES(ar)[1]; + } else { + result.dimensions[1] = result.stride[1] = 0; + } + return result; +} + +} // namespace Numpy + +template <class T> +static void debugArray(std::ostream &str, const T *data, int n) +{ + static const int maxData = 10; + str << " = "; + auto *end = data + std::min(n, maxData); + for (auto *d = data; d != end; ++d) { + if (d != data) + str << ", "; + str << *d; + } + if (n > maxData) + str << "..."; +} + +std::ostream &operator<<(std::ostream &str, const debugPyArrayObject &a) +{ + str << "PyArrayObject("; + if (a.m_object == nullptr) { + str << '0'; + } else if (PyArray_Check(a.m_object) != 0) { + auto *ar = reinterpret_cast<PyArrayObject *>(a.m_object); + const int ndim = PyArray_NDIM(ar); + const int type = PyArray_TYPE(ar); + const int flags = PyArray_FLAGS(ar); + str << "ndim=" << ndim << " ["; + for (int d = 0; d < ndim; ++d) { + if (d) + str << ", "; + str << PyArray_DIMS(ar)[d]; + } + str << "], type="; + switch (type) { + case NPY_SHORT: + str << "short"; + break; + case NPY_USHORT: + str << "ushort"; + break; + case NPY_INT: + str << "int32"; + break; + case NPY_UINT: + str << "uint32"; + break; + case NPY_LONG: + str << "long"; + break; + case NPY_ULONG: + str << "ulong"; + break; + case NPY_LONGLONG: + str << "long long"; + break; + case NPY_ULONGLONG: + str << "ulong long"; + break; + case NPY_FLOAT: + str << "float"; + break; + case NPY_DOUBLE: + str << "double"; + break; + default: + str << '(' << type << ')'; + break; + } + str << ", flags=0x" << std::hex << flags << std::dec; + if ((flags & NPY_ARRAY_C_CONTIGUOUS) != 0) + str << " [C-contiguous]"; + if ((flags & NPY_ARRAY_F_CONTIGUOUS) != 0) + str << " [Fortran-contiguous]"; + if ((flags & NPY_ARRAY_ALIGNED) != 0) + str << " [aligned]"; + if ((flags & NPY_ARRAY_OWNDATA) != 0) + str << " [owndata]"; + if ((flags & NPY_ARRAY_WRITEABLE) != 0) + str << " [writeable]"; + + if (const int dim0 = PyArray_DIMS(ar)[0]) { + auto *data = PyArray_DATA(ar); + switch (type) { + case NPY_SHORT: + debugArray(str, reinterpret_cast<const short *>(data), dim0); + break; + case NPY_USHORT: + debugArray(str, reinterpret_cast<const unsigned short *>(data), dim0); + break; + case NPY_INT: + debugArray(str, reinterpret_cast<const int *>(data), dim0); + break; + case NPY_UINT: + debugArray(str, reinterpret_cast<const unsigned *>(data), dim0); + break; + case NPY_LONG: + debugArray(str, reinterpret_cast<const long *>(data), dim0); + break; + case NPY_ULONG: + debugArray(str, reinterpret_cast<const unsigned long*>(data), dim0); + break; + case NPY_LONGLONG: + debugArray(str, reinterpret_cast<const long long *>(data), dim0); + break; + case NPY_ULONGLONG: + debugArray(str, reinterpret_cast<const unsigned long long *>(data), dim0); + break; + case NPY_FLOAT: + debugArray(str, reinterpret_cast<const float *>(data), dim0); + break; + case NPY_DOUBLE: + debugArray(str, reinterpret_cast<const double *>(data), dim0); + break; + } + } + } else { + str << "Invalid"; + } + str << ')'; + return str; +} + +} //namespace Shiboken + +#else // HAVE_NUMPY + +namespace Shiboken::Numpy +{ + +View View::fromPyObject(PyObject *) +{ + return {}; +} + +std::ostream &operator<<(std::ostream &str, const debugPyArrayObject &) +{ + str << "Unimplemented function " << __FUNCTION__ << ", (numpy was not found)."; + return str; +} + +} //namespace Shiboken::Numpy + +#endif // !HAVE_NUMPY + +namespace Shiboken::Numpy +{ + +bool View::sameLayout(const View &rhs) const +{ + return rhs && *this && ndim == rhs.ndim && type == rhs.type; +} + +bool View::sameSize(const View &rhs) const +{ + return sameLayout(rhs) + && dimensions[0] == rhs.dimensions[0] && dimensions[1] == rhs.dimensions[1]; +} + +std::ostream &operator<<(std::ostream &str, const View &v) +{ + str << "Shiboken::Numpy::View("; + if (v) { + str << "type=" << v.type << ", ndim=" << v.ndim << " [" + << v.dimensions[0]; + if (v.ndim > 1) + str << ", " << v.dimensions[1]; + str << "], stride=[" << v.stride[0]; + if (v.ndim > 1) + str << ", " << v.stride[1]; + str << "], data=" << v.data; + } else { + str << "invalid"; + } + str << ')'; + return str; +} + +} //namespace Shiboken::Numpy |