scirs2_numpy/array_like.rs
1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{
6 intern,
7 sync::PyOnceLock,
8 types::{PyAnyMethods, PyDict},
9 Borrowed, FromPyObject, Py, PyAny, PyErr, PyResult,
10};
11
12use crate::array::PyArrayMethods;
13use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
14
mod sealed {
    /// Private supertrait of [`Coerce`](super::Coerce).
    ///
    /// Because the enclosing `sealed` module is private, code outside this module cannot name
    /// `Sealed`. It therefore cannot implement `Coerce` either, which keeps the set of coercion
    /// policies closed.
    pub trait Sealed {}
}

use sealed::Sealed;

/// Policy deciding whether [`PyArrayLike`] may change the element type of its input.
///
/// This trait is sealed. In this module it is implemented by [`TypeMustMatch`] and
/// [`AllowTypeChange`], and it cannot be implemented outside this module.
pub trait Coerce: Sealed {
    /// `true` if the element type of the input may be converted to the requested one.
    ///
    /// When `true`, NumPy's `asarray` is called with `dtype` set to the target element type.
    /// Existing ndarrays may also go through element-wise `Vec<T>` extraction.
    /// When `false`, no `dtype` is passed, so the input must already have the requested type.
    const ALLOW_TYPE_CHANGE: bool;
}
24
25/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
26#[derive(Debug)]
27pub struct TypeMustMatch;
28
29impl Sealed for TypeMustMatch {}
30
31impl Coerce for TypeMustMatch {
32 const ALLOW_TYPE_CHANGE: bool = false;
33}
34
35/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
36#[derive(Debug)]
37pub struct AllowTypeChange;
38
39impl Sealed for AllowTypeChange {}
40
41impl Coerce for AllowTypeChange {
42 const ALLOW_TYPE_CHANGE: bool = true;
43}
44
45/// Receiver for arrays or array-like types.
46///
47/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
48/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
49/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
50/// or a temporary one created by converting the input type into a NumPy array.
51///
52/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
53/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
54///
55/// # Example
56///
57/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
58///
59/// ```rust
60/// # use pyo3::prelude::*;
61/// use pyo3::py_run;
62/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
63///
64/// #[pyfunction]
65/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
66/// array.as_array().sum()
67/// }
68///
69/// Python::attach(|py| {
70/// let np = get_array_module(py).unwrap();
71/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
72///
73/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
74/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
75/// });
76/// ```
77///
78/// but it will not cast the element type if that is required
79///
80/// ```rust,should_panic
81/// use pyo3::prelude::*;
82/// use pyo3::py_run;
83/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
84///
85/// #[pyfunction]
86/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
87/// array.as_array().sum()
88/// }
89///
90/// Python::attach(|py| {
91/// let np = get_array_module(py).unwrap();
92/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
93///
94/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
95/// });
96/// ```
97///
98/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
99///
100/// ```rust
101/// use pyo3::prelude::*;
102/// use pyo3::py_run;
103/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
104///
105/// #[pyfunction]
106/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
107/// array.as_array().sum()
108/// }
109///
110/// Python::attach(|py| {
111/// let np = get_array_module(py).unwrap();
112/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
113///
114/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
115/// });
116/// ```
117#[derive(Debug)]
118#[repr(transparent)]
119pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
120where
121 T: Element,
122 D: Dimension,
123 C: Coerce;
124
125impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
126where
127 T: Element,
128 D: Dimension,
129 C: Coerce,
130{
131 type Target = PyReadonlyArray<'py, T, D>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.0
135 }
136}
137
138impl<'a, 'py, T, D, C> FromPyObject<'a, 'py> for PyArrayLike<'py, T, D, C>
139where
140 T: Element + 'py,
141 D: Dimension + 'py,
142 C: Coerce,
143 Vec<T>: FromPyObject<'a, 'py>,
144{
145 type Error = PyErr;
146
147 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
148 if let Ok(array) = ob.cast::<PyArray<T, D>>() {
149 return Ok(Self(array.readonly(), PhantomData));
150 }
151
152 let py = ob.py();
153
154 // If the input is already an ndarray and `TypeMustMatch` is used then no type conversion
155 // should be performed.
156 if (C::ALLOW_TYPE_CHANGE || ob.cast::<PyUntypedArray>().is_err())
157 && matches!(D::NDIM, None | Some(1))
158 {
159 if let Ok(vec) = ob.extract::<Vec<T>>() {
160 let array = Array1::from(vec)
161 .into_dimensionality()
162 .expect("D being compatible to Ix1")
163 .into_pyarray(py)
164 .readonly();
165 return Ok(Self(array, PhantomData));
166 }
167 }
168
169 static AS_ARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
170
171 let as_array = AS_ARRAY
172 .get_or_try_init(py, || {
173 get_array_module(py)?.getattr("asarray").map(Into::into)
174 })?
175 .bind(py);
176
177 let kwargs = if C::ALLOW_TYPE_CHANGE {
178 let kwargs = PyDict::new(py);
179 kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
180 Some(kwargs)
181 } else {
182 None
183 };
184
185 let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
186 Ok(Self(array, PhantomData))
187 }
188}
189
190/// Receiver for zero-dimensional arrays or array-like types.
191pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
192
193/// Receiver for one-dimensional arrays or array-like types.
194pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
195
196/// Receiver for two-dimensional arrays or array-like types.
197pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
198
199/// Receiver for three-dimensional arrays or array-like types.
200pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
201
202/// Receiver for four-dimensional arrays or array-like types.
203pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
204
205/// Receiver for five-dimensional arrays or array-like types.
206pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
207
208/// Receiver for six-dimensional arrays or array-like types.
209pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
210
211/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
212pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;