scirs2_numpy/array_subclass.rs
1//! Array subclass support for duck-typed Python objects.
2//!
3//! Provides utilities to accept any Python object that looks like an array:
4//! - NumPy ndarrays (via `__array__` protocol)
5//! - pandas Series (has a `.values` attribute returning an ndarray)
6//! - Any list-like object supporting `__len__` and `__getitem__`
7//!
//! Also exposes [`SubclassArrayWrapper`], a `#[pyclass]` that wraps a flat `f64`
//! buffer with shape metadata and offers a small NumPy-like surface (`len()`,
//! integer indexing, and `shape()` / `dtype()` methods) for duck-typed downstream code.
11
12use pyo3::prelude::*;
13use pyo3::types::PyAnyMethods;
14
15// ──────────────────────────────────────────────────────────────────────────────
16// Free-standing extraction helpers
17// ──────────────────────────────────────────────────────────────────────────────
18
19/// Extract `f32` values from any Python array-like object.
20///
21/// Attempts the following strategies in order:
22/// 1. `.values` attribute (pandas Series / masked array).
23/// 2. `.__array__()` method (NumPy array protocol).
24/// 3. Direct iteration via `__len__` + `__getitem__`.
25///
26/// # Errors
27/// Returns a [`PyErr`] if none of the strategies succeeds or if an element
28/// cannot be converted to `f32`.
29#[pyfunction]
30pub fn from_array_like_f32(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f32>> {
31 // Strategy 1: .values attribute (pandas Series / masked array).
32 // Guard: only recurse if the attribute is NOT already a plain ndarray
33 // (to avoid infinite recursion on ndarray objects that happen to lack .values).
34 if let Ok(values) = obj.getattr("values") {
35 // If `values` itself has a .values attribute we stop (avoid deep recursion).
36 if values.getattr("values").is_err() {
37 return from_array_like_f32(&values);
38 }
39 }
40
41 // Strategy 2: __array__ protocol.
42 if let Ok(arr) = obj.call_method0("__array__") {
43 // Recursing here is safe: the resulting ndarray has no .values attribute.
44 return from_array_like_f32(&arr);
45 }
46
47 // Strategy 3: direct iteration.
48 let len = obj.len()?;
49 let mut result = Vec::with_capacity(len);
50 for i in 0..len {
51 let item = obj.get_item(i)?;
52 let val: f32 = item.extract()?;
53 result.push(val);
54 }
55 Ok(result)
56}
57
58/// Extract `f64` values from any Python array-like object.
59///
60/// Attempts the following strategies in order:
61/// 1. `.values` attribute (pandas Series / masked array).
62/// 2. `.__array__()` method (NumPy array protocol).
63/// 3. Direct iteration via `__len__` + `__getitem__`.
64///
65/// # Errors
66/// Returns a [`PyErr`] if none of the strategies succeeds or if an element
67/// cannot be converted to `f64`.
68#[pyfunction]
69pub fn from_array_like_f64(obj: &Bound<'_, PyAny>) -> PyResult<Vec<f64>> {
70 // Strategy 1: .values attribute — with depth guard.
71 if let Ok(values) = obj.getattr("values") {
72 if values.getattr("values").is_err() {
73 return from_array_like_f64(&values);
74 }
75 }
76
77 // Strategy 2: __array__ protocol.
78 if let Ok(arr) = obj.call_method0("__array__") {
79 return from_array_like_f64(&arr);
80 }
81
82 // Strategy 3: direct iteration.
83 let len = obj.len()?;
84 let mut result = Vec::with_capacity(len);
85 for i in 0..len {
86 let item = obj.get_item(i)?;
87 let val: f64 = item.extract()?;
88 result.push(val);
89 }
90 Ok(result)
91}
92
93// ──────────────────────────────────────────────────────────────────────────────
94// SubclassArrayWrapper
95// ──────────────────────────────────────────────────────────────────────────────
96
97/// A Python-visible wrapper around a flat `f64` data buffer with shape metadata.
98///
99/// Implements enough of the NumPy duck-typing surface to be accepted by code
100/// that inspects `.shape`, `.dtype`, `.__len__`, and `.__getitem__`.
101#[pyclass(name = "SubclassArrayWrapper")]
102pub struct SubclassArrayWrapper {
103 /// Flat data buffer in C (row-major) order.
104 data: Vec<f64>,
105 /// Logical shape of the array.
106 shape: Vec<usize>,
107 /// NumPy-compatible dtype string (e.g. `"float64"`).
108 dtype: String,
109}
110
111#[pymethods]
112impl SubclassArrayWrapper {
113 /// Construct a new wrapper.
114 ///
115 /// # Arguments
116 /// * `data` – flat element buffer; length must equal the product of `shape`.
117 /// * `shape` – logical shape; `[]` is interpreted as a 0-d scalar.
118 /// * `dtype` – NumPy-compatible dtype string such as `"float64"`.
119 #[new]
120 #[pyo3(signature = (data, shape, dtype = "float64".to_string()))]
121 pub fn new(data: Vec<f64>, shape: Vec<usize>, dtype: String) -> PyResult<Self> {
122 let n: usize = shape.iter().product::<usize>().max(1);
123 if shape.is_empty() {
124 // 0-d: exactly one element.
125 if data.len() != 1 {
126 return Err(pyo3::exceptions::PyValueError::new_err(
127 "0-d SubclassArrayWrapper requires exactly one element",
128 ));
129 }
130 } else if data.len() != n {
131 return Err(pyo3::exceptions::PyValueError::new_err(format!(
132 "data length {} does not match shape product {}",
133 data.len(),
134 n
135 )));
136 }
137 Ok(Self { data, shape, dtype })
138 }
139
140 /// Number of elements (flat).
141 pub fn __len__(&self) -> usize {
142 self.data.len()
143 }
144
145 /// Return element at flat index `idx`.
146 ///
147 /// # Errors
148 /// Returns `IndexError` if `idx` is out of bounds.
149 pub fn __getitem__(&self, idx: usize) -> PyResult<f64> {
150 self.data.get(idx).copied().ok_or_else(|| {
151 pyo3::exceptions::PyIndexError::new_err(format!(
152 "index {} out of bounds for array of length {}",
153 idx,
154 self.data.len()
155 ))
156 })
157 }
158
159 /// Return `self` — mirrors pandas `.values` which returns the underlying array.
160 ///
161 /// This makes `SubclassArrayWrapper` itself accepted by [`from_array_like_f64`].
162 pub fn values(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
163 slf
164 }
165
166 /// Logical shape tuple.
167 pub fn shape(&self) -> Vec<usize> {
168 self.shape.clone()
169 }
170
171 /// NumPy-compatible dtype string.
172 pub fn dtype(&self) -> &str {
173 &self.dtype
174 }
175
176 /// Flat copy of the data as a Python list.
177 pub fn to_list(&self) -> Vec<f64> {
178 self.data.clone()
179 }
180}
181
182// ──────────────────────────────────────────────────────────────────────────────
183// Module registration
184// ──────────────────────────────────────────────────────────────────────────────
185
186/// Register array-subclass helpers and [`SubclassArrayWrapper`] into a PyO3 module.
187pub fn register_array_subclass_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
188 m.add_function(wrap_pyfunction!(from_array_like_f32, m)?)?;
189 m.add_function(wrap_pyfunction!(from_array_like_f64, m)?)?;
190 m.add_class::<SubclassArrayWrapper>()?;
191 Ok(())
192}
193
194// ──────────────────────────────────────────────────────────────────────────────
195// Tests
196// ──────────────────────────────────────────────────────────────────────────────
197
#[cfg(test)]
mod tests {
    use super::*;

    // ── Extraction helpers ───────────────────────────────────────────────────

    #[test]
    fn array_like_extracts_from_list() {
        Python::attach(|py| {
            let list = py
                .eval(pyo3::ffi::c_str!("[1.0, 2.0, 3.0]"), None, None)
                .expect("eval failed");
            let result = from_array_like_f64(&list).expect("extraction failed");
            assert_eq!(result, vec![1.0, 2.0, 3.0]);
        });
    }

    #[test]
    fn array_like_f32_extracts_from_list() {
        Python::attach(|py| {
            let list = py
                .eval(pyo3::ffi::c_str!("[0.5, -1.5]"), None, None)
                .expect("eval failed");
            let result = from_array_like_f32(&list).expect("extraction failed");
            assert_eq!(result, vec![0.5f32, -1.5]);
        });
    }

    #[test]
    fn array_like_rejects_non_numeric_elements() {
        Python::attach(|py| {
            let list = py
                .eval(pyo3::ffi::c_str!("['a', 'b']"), None, None)
                .expect("eval failed");
            assert!(from_array_like_f64(&list).is_err());
        });
    }

    #[test]
    fn array_like_rejects_unsized_object() {
        Python::attach(|py| {
            // A plain int has no `.values`, no `__array__`, and no `__len__`.
            let scalar = py
                .eval(pyo3::ffi::c_str!("5"), None, None)
                .expect("eval failed");
            assert!(from_array_like_f64(&scalar).is_err());
        });
    }

    // ── SubclassArrayWrapper ─────────────────────────────────────────────────

    #[test]
    fn array_like_wrapper_len_correct() {
        let wrapper =
            SubclassArrayWrapper::new(vec![1.0, 2.0, 3.0], vec![3], "float64".to_string())
                .expect("construction failed");
        assert_eq!(wrapper.__len__(), 3);
    }

    #[test]
    fn subclass_wrapper_getitem_correct() {
        let wrapper =
            SubclassArrayWrapper::new(vec![10.0, 20.0, 30.0], vec![3], "float64".to_string())
                .expect("construction failed");
        assert!((wrapper.__getitem__(1).expect("index valid") - 20.0).abs() < f64::EPSILON);
    }

    #[test]
    fn subclass_wrapper_getitem_oob() {
        let wrapper = SubclassArrayWrapper::new(vec![1.0], vec![1], "float64".to_string())
            .expect("construction failed");
        assert!(wrapper.__getitem__(99).is_err());
    }

    #[test]
    fn subclass_wrapper_shape_and_dtype() {
        let wrapper = SubclassArrayWrapper::new(vec![1.0, 2.0], vec![2], "float64".to_string())
            .expect("construction failed");
        assert_eq!(wrapper.shape(), vec![2usize]);
        assert_eq!(wrapper.dtype(), "float64");
    }

    #[test]
    fn subclass_wrapper_rejects_length_mismatch() {
        let result = SubclassArrayWrapper::new(vec![1.0, 2.0], vec![3], "float64".to_string());
        assert!(result.is_err());
    }

    #[test]
    fn subclass_wrapper_zero_dim_requires_one_element() {
        assert!(SubclassArrayWrapper::new(vec![5.0], vec![], "float64".to_string()).is_ok());
        assert!(SubclassArrayWrapper::new(vec![], vec![], "float64".to_string()).is_err());
    }

    #[test]
    fn subclass_wrapper_accepts_multi_dim_shape() {
        let wrapper = SubclassArrayWrapper::new(vec![0.0; 6], vec![2, 3], "float64".to_string())
            .expect("construction failed");
        assert_eq!(wrapper.__len__(), 6);
        assert_eq!(wrapper.shape(), vec![2usize, 3]);
    }
}
243}