Skip to main content

scirs2_numpy/
array_protocol.rs

1//! NumPy array protocol (`__array__` and `__array_interface__`) support.
2//!
3//! The NumPy array protocol enables Python objects to be converted to NumPy
4//! arrays via two mechanisms:
5//!
6//! - `__array__(dtype=None)` — a method that returns a `numpy.ndarray`.
7//! - `__array_interface__` — a property returning a Python dict describing
8//!   the underlying buffer (shape, dtype string, raw pointer, etc.).
9//!
10//! The array interface dictionary format follows:
11//! <https://numpy.org/doc/stable/reference/arrays.interface.html>
12
13use pyo3::prelude::*;
14use pyo3::types::{PyDict, PyList, PyTuple};
15use thiserror::Error;
16
17// ─── Error types ────────────────────────────────────────────────────────────
18
19/// Errors produced by the array-protocol layer.
20#[derive(Debug, Error)]
21pub enum ArrayProtocolError {
22    /// The requested element dtype is not supported.
23    #[error("unsupported dtype: {0}")]
24    UnsupportedDtype(String),
25
26    /// The numpy typestr string could not be parsed.
27    #[error("invalid typestr: {0}")]
28    InvalidTypestr(String),
29
30    /// A Python API call failed.
31    #[error("python error: {0}")]
32    PythonError(String),
33}
34
35impl From<PyErr> for ArrayProtocolError {
36    fn from(e: PyErr) -> Self {
37        Self::PythonError(e.to_string())
38    }
39}
40
41impl From<ArrayProtocolError> for PyErr {
42    fn from(e: ArrayProtocolError) -> Self {
43        pyo3::exceptions::PyValueError::new_err(e.to_string())
44    }
45}
46
47// ─── parse_typestr ───────────────────────────────────────────────────────────
48
49/// Parse a NumPy type-string into `(kind_char, byte_count)`.
50///
51/// NumPy typestrings have the format `<endian><kind><bytes>`, where:
52/// - endianness: `'<'` (little), `'>'` (big), `'='` (native), `'|'` (n/a)
53/// - kind: `'f'` float, `'i'` signed int, `'u'` unsigned int, `'b'` bool,
54///   `'c'` complex, etc.
55/// - bytes: decimal byte count, e.g. `8` for 64-bit.
56///
57/// Returns `(kind, byte_count)`.
58///
59/// # Examples
60///
61/// ```
62/// use scirs2_numpy::array_protocol::parse_typestr;
63/// let (kind, bytes) = parse_typestr("<f8").unwrap();
64/// assert_eq!(kind, 'f');
65/// assert_eq!(bytes, 8);
66/// ```
67pub fn parse_typestr(typestr: &str) -> Result<(char, usize), ArrayProtocolError> {
68    if typestr.len() < 3 {
69        return Err(ArrayProtocolError::InvalidTypestr(format!(
70            "too short: {typestr:?}"
71        )));
72    }
73    let mut chars = typestr.chars();
74    let endian = chars
75        .next()
76        .ok_or_else(|| ArrayProtocolError::InvalidTypestr(format!("empty typestr: {typestr:?}")))?;
77    // Validate endianness character.
78    if !matches!(endian, '<' | '>' | '=' | '|') {
79        return Err(ArrayProtocolError::InvalidTypestr(format!(
80            "unknown endianness character {endian:?} in {typestr:?}"
81        )));
82    }
83    let kind = chars.next().ok_or_else(|| {
84        ArrayProtocolError::InvalidTypestr(format!("missing kind in {typestr:?}"))
85    })?;
86    if !kind.is_ascii_alphabetic() {
87        return Err(ArrayProtocolError::InvalidTypestr(format!(
88            "invalid kind character {kind:?} in {typestr:?}"
89        )));
90    }
91    let size_str: String = chars.collect();
92    let byte_count = size_str.parse::<usize>().map_err(|_| {
93        ArrayProtocolError::InvalidTypestr(format!(
94            "invalid byte count {size_str:?} in {typestr:?}"
95        ))
96    })?;
97    if byte_count == 0 {
98        return Err(ArrayProtocolError::InvalidTypestr(format!(
99            "byte count must be > 0 in {typestr:?}"
100        )));
101    }
102    Ok((kind, byte_count))
103}
104
105// ─── ArrayProtocol trait ────────────────────────────────────────────────────
106
107/// Mixin trait for types that support the NumPy array interface protocol.
108///
109/// Implementors must provide shape, stride, type-string, and a raw data
110/// pointer, from which a complete [`ArrayInterfaceDict`] can be assembled.
111pub trait ArrayProtocol {
112    /// Returns the populated [`ArrayInterfaceDict`] for this object.
113    fn array_interface(&self) -> ArrayInterfaceDict;
114
115    /// Returns the NumPy dtype type-string (e.g. `"<f8"` for little-endian f64).
116    fn dtype_str(&self) -> &'static str;
117
118    /// Returns the logical shape of the array.
119    fn shape(&self) -> Vec<usize>;
120
121    /// Returns the strides of the array **in bytes**.
122    fn strides(&self) -> Vec<usize>;
123
124    /// Returns a raw pointer to the first byte of element data.
125    fn data_ptr(&self) -> *const u8;
126
127    /// Returns the total number of bytes occupied by the element buffer.
128    fn nbytes(&self) -> usize;
129}
130
131// ─── ArrayInterfaceDict ──────────────────────────────────────────────────────
132
/// Data for the `__array_interface__` protocol dictionary.
///
/// This is a plain-Rust mirror of the dict. Call
/// [`ArrayInterfaceDict::to_py_dict`] to turn it into a Python object.
///
/// See: <https://numpy.org/doc/stable/reference/arrays.interface.html>
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrayInterfaceDict {
    /// Logical shape of the array.
    pub shape: Vec<usize>,
    /// NumPy dtype typestr (e.g. `"<f8"`).
    pub typestr: String,
    /// Address of the first element, emitted as a Python integer.
    pub data_ptr: usize,
    /// Whether the buffer should be treated as read-only.
    pub readonly: bool,
    /// Optional per-dimension strides in bytes.
    pub strides: Option<Vec<usize>>,
    /// Protocol version; callers in this module always set 3.
    pub version: u8,
}
150
151impl ArrayInterfaceDict {
152    /// Serialize this descriptor into a Python dict suitable for `__array_interface__`.
153    ///
154    /// The resulting dict has the keys `shape`, `typestr`, `data`, `version`,
155    /// and optionally `strides`.
156    pub fn to_py_dict<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
157        let dict = PyDict::new(py);
158
159        // shape — tuple of usize
160        let shape_tuple = PyTuple::new(py, self.shape.iter().copied())?;
161        dict.set_item("shape", shape_tuple)?;
162
163        // typestr — str
164        dict.set_item("typestr", &self.typestr)?;
165
166        // data — (ptr_as_int, readonly_bool)
167        let data_tuple = PyTuple::new(py, [self.data_ptr, self.readonly as usize])?;
168        dict.set_item("data", data_tuple)?;
169
170        // version — always 3
171        dict.set_item("version", self.version)?;
172
173        // strides — optional tuple of usize
174        if let Some(ref strides) = self.strides {
175            let strides_tuple = PyTuple::new(py, strides.iter().copied())?;
176            dict.set_item("strides", strides_tuple)?;
177        }
178
179        Ok(dict)
180    }
181}
182
183// ─── NdArrayWrapper ──────────────────────────────────────────────────────────
184
185/// A concrete array type implementing the NumPy `__array__` and
186/// `__array_interface__` protocols.
187///
188/// Wraps an owned flat `Vec<f64>` buffer with a logical shape, and exposes
189/// it to NumPy via the array interface protocol.
190#[pyclass(name = "NdArrayWrapper")]
191pub struct NdArrayWrapper {
192    /// Flat element buffer in C (row-major) order.
193    data: Vec<f64>,
194    /// Logical shape.
195    shape: Vec<usize>,
196    /// Per-dimension strides **in bytes** (C-contiguous by default).
197    strides: Vec<usize>,
198    /// NumPy dtype typestr.
199    dtype: String,
200}
201
202#[pymethods]
203impl NdArrayWrapper {
204    /// Construct a new `NdArrayWrapper` with C-contiguous strides.
205    ///
206    /// # Arguments
207    /// * `data`  – flat element buffer; must have `shape.iter().product::<usize>()` elements.
208    /// * `shape` – logical dimensions.
209    #[new]
210    pub fn new(data: Vec<f64>, shape: Vec<usize>) -> PyResult<Self> {
211        let n: usize = shape.iter().product();
212        if data.len() != n {
213            return Err(pyo3::exceptions::PyValueError::new_err(format!(
214                "data length {} does not match shape product {}",
215                data.len(),
216                n
217            )));
218        }
219        let strides = compute_c_strides_bytes(&shape, std::mem::size_of::<f64>());
220        Ok(Self {
221            data,
222            shape,
223            strides,
224            dtype: "<f8".to_owned(),
225        })
226    }
227
228    /// Return a Python representation suitable for numpy consumption.
229    ///
230    /// Calls `numpy.array(list_of_floats).reshape(shape)` so that consumers
231    /// that call `np.asarray(obj)` or `np.array(obj.__array__())` obtain the
232    /// correct array.
233    ///
234    /// Note: requires NumPy to be installed in the active Python environment.
235    #[pyo3(name = "__array__")]
236    pub fn array_method(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
237        let np = py.import("numpy").map_err(|e| {
238            pyo3::exceptions::PyImportError::new_err(format!("numpy not available: {e}"))
239        })?;
240        // Build a flat Python list from the data buffer.
241        let flat_list = PyList::new(py, &self.data)?;
242        // numpy.array(flat_list, dtype='f8')
243        let kwargs = PyDict::new(py);
244        kwargs.set_item("dtype", "f8")?;
245        let arr = np.call_method("array", (flat_list,), Some(&kwargs))?;
246        // Reshape to logical shape.
247        let shape_tuple = PyTuple::new(py, self.shape.iter().copied())?;
248        let reshaped = arr.call_method1("reshape", (shape_tuple,))?;
249        Ok(reshaped.unbind())
250    }
251
252    /// The `__array_interface__` property, returning a dict describing the buffer.
253    #[getter]
254    pub fn array_interface(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
255        let desc = ArrayInterfaceDict {
256            shape: self.shape.clone(),
257            typestr: self.dtype.clone(),
258            data_ptr: self.data.as_ptr() as usize,
259            readonly: true,
260            strides: Some(self.strides.clone()),
261            version: 3,
262        };
263        let dict = desc.to_py_dict(py)?;
264        Ok(dict.into_any().unbind())
265    }
266
267    /// Return the shape as a Python tuple.
268    pub fn shape_tuple(&self, py: Python<'_>) -> Py<PyAny> {
269        PyTuple::new(py, self.shape.iter().copied())
270            .map(|t| t.into_any().unbind())
271            .unwrap_or_else(|_| py.None())
272    }
273
274    /// Return the dtype typestr (e.g. `"<f8"`).
275    pub fn dtype_str(&self) -> &str {
276        &self.dtype
277    }
278
279    /// Return a flat copy of the data buffer.
280    pub fn data(&self) -> Vec<f64> {
281        self.data.clone()
282    }
283
284    /// Return the number of dimensions.
285    pub fn ndim(&self) -> usize {
286        self.shape.len()
287    }
288}
289
290impl ArrayProtocol for NdArrayWrapper {
291    fn array_interface(&self) -> ArrayInterfaceDict {
292        ArrayInterfaceDict {
293            shape: self.shape.clone(),
294            typestr: self.dtype.clone(),
295            data_ptr: self.data.as_ptr() as usize,
296            readonly: true,
297            strides: Some(self.strides.clone()),
298            version: 3,
299        }
300    }
301
302    fn dtype_str(&self) -> &'static str {
303        "<f8"
304    }
305
306    fn shape(&self) -> Vec<usize> {
307        self.shape.clone()
308    }
309
310    fn strides(&self) -> Vec<usize> {
311        self.strides.clone()
312    }
313
314    fn data_ptr(&self) -> *const u8 {
315        self.data.as_ptr() as *const u8
316    }
317
318    fn nbytes(&self) -> usize {
319        self.data.len() * std::mem::size_of::<f64>()
320    }
321}
322
323// ─── Register ───────────────────────────────────────────────────────────────
324
325/// Register array-protocol classes into a PyO3 module.
326pub fn register_array_protocol_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
327    m.add_class::<NdArrayWrapper>()?;
328    Ok(())
329}
330
331// ─── Helpers ─────────────────────────────────────────────────────────────────
332
/// Compute row-major (C-contiguous) byte strides for `shape`.
///
/// The innermost axis advances by `elem_size` bytes. Every outer axis advances
/// by `elem_size` times the product of all axes after it. An empty shape (0-d)
/// yields no strides. The leading axis length never takes part in any product.
fn compute_c_strides_bytes(shape: &[usize], elem_size: usize) -> Vec<usize> {
    let Some((_, trailing)) = shape.split_first() else {
        return Vec::new();
    };
    // Build the strides innermost-first, then flip them into axis order.
    let mut strides: Vec<usize> = std::iter::once(elem_size)
        .chain(trailing.iter().rev().scan(elem_size, |running, &dim| {
            *running *= dim;
            Some(*running)
        }))
        .collect();
    strides.reverse();
    strides
}
348
349// ─── Tests ───────────────────────────────────────────────────────────────────
350
#[cfg(test)]
mod tests {
    use super::*;

    // --- parse_typestr ---

    #[test]
    fn test_parse_typestr_f64_le() {
        let (kind, bytes) = parse_typestr("<f8").expect("parse_typestr failed");
        assert_eq!(kind, 'f');
        assert_eq!(bytes, 8);
    }

    #[test]
    fn test_parse_typestr_i32_be() {
        let (kind, bytes) = parse_typestr(">i4").expect("parse_typestr failed");
        assert_eq!(kind, 'i');
        assert_eq!(bytes, 4);
    }

    #[test]
    fn test_parse_typestr_u16_native() {
        let (kind, bytes) = parse_typestr("=u2").expect("parse_typestr failed");
        assert_eq!(kind, 'u');
        assert_eq!(bytes, 2);
    }

    #[test]
    fn test_parse_typestr_bool_noendian() {
        let (kind, bytes) = parse_typestr("|b1").expect("parse_typestr failed");
        assert_eq!(kind, 'b');
        assert_eq!(bytes, 1);
    }

    #[test]
    fn test_parse_typestr_multi_digit_size() {
        let (kind, bytes) = parse_typestr("<U16").expect("parse_typestr failed");
        assert_eq!(kind, 'U');
        assert_eq!(bytes, 16);
    }

    #[test]
    fn test_parse_typestr_error_too_short() {
        assert!(parse_typestr("<f").is_err());
        assert!(parse_typestr("").is_err());
        assert!(parse_typestr("<").is_err());
    }

    #[test]
    fn test_parse_typestr_error_bad_endian() {
        assert!(parse_typestr("?f8").is_err());
    }

    #[test]
    fn test_parse_typestr_error_bad_kind() {
        // Kind must be an ASCII letter; here a digit sits in the kind position.
        assert!(parse_typestr("<8f").is_err());
    }

    #[test]
    fn test_parse_typestr_error_zero_bytes() {
        assert!(parse_typestr("<f0").is_err());
    }

    // --- ArrayInterfaceDict ---

    #[test]
    fn test_array_interface_dict_version() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0];
        let wrapper = NdArrayWrapper::new(data, vec![2, 2]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.version, 3, "version must be 3");
    }

    #[test]
    fn test_array_interface_dict_shape() {
        let data = vec![1.0_f64; 6];
        let wrapper = NdArrayWrapper::new(data, vec![2, 3]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.shape, vec![2, 3]);
    }

    #[test]
    fn test_array_interface_dict_typestr() {
        let data = vec![0.0_f64; 4];
        let wrapper = NdArrayWrapper::new(data, vec![4]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_eq!(iface.typestr, "<f8");
    }

    #[test]
    fn test_array_interface_dict_data_ptr_nonzero() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let wrapper = NdArrayWrapper::new(data, vec![3]).expect("NdArrayWrapper::new failed");
        let iface = ArrayProtocol::array_interface(&wrapper);
        assert_ne!(iface.data_ptr, 0, "data pointer must be non-null");
    }

    // --- NdArrayWrapper construction ---

    #[test]
    fn test_ndarray_wrapper_shape_mismatch() {
        // data has 4 elements but shape says 6
        let result = NdArrayWrapper::new(vec![1.0; 4], vec![2, 3]);
        assert!(result.is_err());
    }

    #[test]
    fn test_ndarray_wrapper_scalar() {
        // A single-element 1-d array (shape [1]); see
        // `test_ndarray_wrapper_zero_dim` for a genuine 0-d array.
        let wrapper = NdArrayWrapper::new(vec![42.0], vec![1]).expect("scalar failed");
        assert_eq!(wrapper.ndim(), 1);
        assert_eq!(wrapper.data(), vec![42.0]);
    }

    #[test]
    fn test_ndarray_wrapper_zero_dim() {
        // Empty shape: product of no dimensions is 1 element, and there are no strides.
        let wrapper = NdArrayWrapper::new(vec![42.0], vec![]).expect("0-d failed");
        assert_eq!(wrapper.ndim(), 0);
        assert!(ArrayProtocol::strides(&wrapper).is_empty());
        assert_eq!(ArrayProtocol::nbytes(&wrapper), 8);
    }

    #[test]
    fn test_ndarray_wrapper_zero_length_axis() {
        let wrapper =
            NdArrayWrapper::new(Vec::new(), vec![2, 0, 3]).expect("zero-length axis failed");
        assert_eq!(ArrayProtocol::shape(&wrapper), vec![2, 0, 3]);
        assert_eq!(ArrayProtocol::nbytes(&wrapper), 0);
    }

    #[test]
    fn test_ndarray_wrapper_nbytes() {
        let wrapper =
            NdArrayWrapper::new(vec![0.0_f64; 12], vec![3, 4]).expect("NdArrayWrapper::new failed");
        assert_eq!(ArrayProtocol::nbytes(&wrapper), 96);
    }

    #[test]
    fn test_ndarray_wrapper_strides_c_order() {
        // shape [3, 4] → strides [32, 8] (in bytes, f64=8)
        let data = vec![0.0_f64; 12];
        let wrapper = NdArrayWrapper::new(data, vec![3, 4]).expect("NdArrayWrapper::new failed");
        let strides = ArrayProtocol::strides(&wrapper);
        assert_eq!(strides, vec![32, 8]);
    }

    // --- Python-GIL tests (require auto-initialize feature) ---

    #[test]
    fn test_array_interface_py_dict_keys() {
        Python::attach(|py| {
            let data = vec![1.0_f64, 2.0, 3.0, 4.0];
            let wrapper =
                NdArrayWrapper::new(data, vec![2, 2]).expect("NdArrayWrapper::new failed");
            let iface = ArrayProtocol::array_interface(&wrapper);
            let dict = iface.to_py_dict(py).expect("to_py_dict failed");

            assert!(dict
                .get_item("shape")
                .expect("shape lookup failed")
                .is_some());
            assert!(dict
                .get_item("typestr")
                .expect("typestr lookup failed")
                .is_some());
            assert!(dict.get_item("data").expect("data lookup failed").is_some());
            assert!(dict
                .get_item("version")
                .expect("version lookup failed")
                .is_some());
        });
    }

    #[test]
    fn test_array_interface_py_dict_shape_values() {
        Python::attach(|py| {
            let data = vec![0.0_f64; 6];
            let wrapper =
                NdArrayWrapper::new(data, vec![2, 3]).expect("NdArrayWrapper::new failed");
            let iface = ArrayProtocol::array_interface(&wrapper);
            let dict = iface.to_py_dict(py).expect("to_py_dict failed");

            let shape_obj = dict
                .get_item("shape")
                .expect("shape lookup failed")
                .expect("shape missing");
            let shape_tuple = shape_obj.cast::<PyTuple>().expect("shape is not a tuple");
            assert_eq!(shape_tuple.len(), 2);
            let v0: usize = shape_tuple
                .get_item(0)
                .expect("item 0")
                .extract()
                .expect("extract[0]");
            let v1: usize = shape_tuple
                .get_item(1)
                .expect("item 1")
                .extract()
                .expect("extract[1]");
            assert_eq!(v0, 2);
            assert_eq!(v1, 3);
        });
    }

    #[test]
    fn test_array_interface_py_dict_data_pointer() {
        Python::attach(|py| {
            let wrapper = NdArrayWrapper::new(vec![1.0_f64, 2.0, 3.0], vec![3])
                .expect("NdArrayWrapper::new failed");
            let iface = ArrayProtocol::array_interface(&wrapper);
            let expected_ptr = iface.data_ptr;
            let dict = iface.to_py_dict(py).expect("to_py_dict failed");

            let data_obj = dict
                .get_item("data")
                .expect("data lookup failed")
                .expect("data missing");
            let data_tuple = data_obj.cast::<PyTuple>().expect("data is not a tuple");
            assert_eq!(data_tuple.len(), 2);
            let ptr: usize = data_tuple
                .get_item(0)
                .expect("item 0")
                .extract()
                .expect("extract ptr");
            assert_eq!(ptr, expected_ptr);
        });
    }
}