Skip to main content

scirs2_numpy/
untyped.rs

1//! Runtime-typed array: a byte buffer with a runtime dtype descriptor.
2//!
3//! [`UntypedArray`] stores elements without a compile-time type parameter.
4//! This is useful in generic entry-point functions that need to inspect
5//! the dtype of an incoming array before deciding how to process it,
6//! analogous to `numpy.ndarray` with an unspecified element type.
7//!
8//! For the statically-typed `PyUntypedArray` wrapper around NumPy C-API
9//! objects, see `untyped_array`.
10
11use pyo3::prelude::*;
12
13/// A runtime-typed multi-dimensional array backed by a flat byte buffer.
14///
15/// Elements can be read and written as `f64` regardless of the underlying
16/// storage dtype — the conversion is applied automatically.
17#[pyclass(name = "UntypedArray")]
18pub struct UntypedArray {
19    /// Raw backing buffer; length is `n_elements * itemsize`.
20    data: Vec<u8>,
21    /// Canonical dtype name (e.g. `"float64"`, `"int32"`).
22    dtype_name: String,
23    /// Logical shape.
24    shape: Vec<usize>,
25    /// Byte size of a single element.
26    itemsize: usize,
27}
28
29#[pymethods]
30impl UntypedArray {
31    /// Construct a zero-filled untyped array.
32    ///
33    /// # Arguments
34    /// * `shape`      – logical dimensions.
35    /// * `dtype_name` – dtype string; one of: `"float32"`, `"f32"`, `"float64"`, `"f64"`,
36    ///                  `"int32"`, `"i32"`, `"int64"`, `"i64"`, `"bool"`, `"b"`,
37    ///                  `"uint8"`, `"u8"`, `"int8"`, `"i8"`.
38    #[new]
39    pub fn new(shape: Vec<usize>, dtype_name: String) -> PyResult<Self> {
40        let itemsize = resolve_itemsize(&dtype_name)?;
41        let n: usize = shape.iter().product::<usize>().max(1);
42        Ok(Self {
43            data: vec![0u8; n * itemsize],
44            dtype_name,
45            shape,
46            itemsize,
47        })
48    }
49
50    /// Return the canonical dtype name.
51    pub fn dtype_name(&self) -> &str {
52        &self.dtype_name
53    }
54
55    /// Return the byte size of a single element.
56    pub fn itemsize(&self) -> usize {
57        self.itemsize
58    }
59
60    /// Return the logical shape.
61    pub fn shape(&self) -> Vec<usize> {
62        self.shape.clone()
63    }
64
65    /// Return the number of dimensions.
66    pub fn ndim(&self) -> usize {
67        self.shape.len()
68    }
69
70    /// Return the total byte size of the backing buffer.
71    pub fn nbytes(&self) -> usize {
72        self.data.len()
73    }
74
75    /// Return the total number of elements (product of shape).
76    pub fn size(&self) -> usize {
77        self.shape.iter().product()
78    }
79
80    /// Return `true` if the element dtype is a floating-point type.
81    pub fn is_floating(&self) -> bool {
82        matches!(
83            self.dtype_name.as_str(),
84            "float32" | "f32" | "float64" | "f64"
85        )
86    }
87
88    /// Return `true` if the element dtype is an integer type.
89    pub fn is_integer(&self) -> bool {
90        matches!(
91            self.dtype_name.as_str(),
92            "int32" | "i32" | "int64" | "i64" | "int8" | "i8" | "uint8" | "u8"
93        )
94    }
95
96    /// Read element at `flat_index` and cast it to `f64`.
97    ///
98    /// Returns `PyIndexError` if the index is out of bounds.
99    pub fn read_as_f64(&self, flat_index: usize) -> PyResult<f64> {
100        let offset = flat_index * self.itemsize;
101        if offset + self.itemsize > self.data.len() {
102            return Err(pyo3::exceptions::PyIndexError::new_err(format!(
103                "flat_index {flat_index} is out of bounds"
104            )));
105        }
106        let value = match self.dtype_name.as_str() {
107            "float32" | "f32" => {
108                let bytes: [u8; 4] = self.data[offset..offset + 4].try_into().map_err(|_| {
109                    pyo3::exceptions::PyValueError::new_err("slice conversion error (f32)")
110                })?;
111                f32::from_le_bytes(bytes) as f64
112            }
113            "float64" | "f64" => {
114                let bytes: [u8; 8] = self.data[offset..offset + 8].try_into().map_err(|_| {
115                    pyo3::exceptions::PyValueError::new_err("slice conversion error (f64)")
116                })?;
117                f64::from_le_bytes(bytes)
118            }
119            "int32" | "i32" => {
120                let bytes: [u8; 4] = self.data[offset..offset + 4].try_into().map_err(|_| {
121                    pyo3::exceptions::PyValueError::new_err("slice conversion error (i32)")
122                })?;
123                i32::from_le_bytes(bytes) as f64
124            }
125            "int64" | "i64" => {
126                let bytes: [u8; 8] = self.data[offset..offset + 8].try_into().map_err(|_| {
127                    pyo3::exceptions::PyValueError::new_err("slice conversion error (i64)")
128                })?;
129                i64::from_le_bytes(bytes) as f64
130            }
131            "int8" | "i8" => self.data[offset] as i8 as f64,
132            "uint8" | "u8" | "bool" | "b" => self.data[offset] as f64,
133            _ => 0.0,
134        };
135        Ok(value)
136    }
137
138    /// Write `value` (as `f64`) to element at `flat_index`, casting to the array's dtype.
139    ///
140    /// Returns `PyIndexError` if the index is out of bounds.
141    pub fn write_f64(&mut self, flat_index: usize, value: f64) -> PyResult<()> {
142        let offset = flat_index * self.itemsize;
143        if offset + self.itemsize > self.data.len() {
144            return Err(pyo3::exceptions::PyIndexError::new_err(format!(
145                "flat_index {flat_index} is out of bounds"
146            )));
147        }
148        match self.dtype_name.as_str() {
149            "float32" | "f32" => {
150                self.data[offset..offset + 4].copy_from_slice(&(value as f32).to_le_bytes());
151            }
152            "float64" | "f64" => {
153                self.data[offset..offset + 8].copy_from_slice(&value.to_le_bytes());
154            }
155            "int32" | "i32" => {
156                self.data[offset..offset + 4].copy_from_slice(&(value as i32).to_le_bytes());
157            }
158            "int64" | "i64" => {
159                self.data[offset..offset + 8].copy_from_slice(&(value as i64).to_le_bytes());
160            }
161            "int8" | "i8" => {
162                self.data[offset] = value as i8 as u8;
163            }
164            "uint8" | "u8" => {
165                self.data[offset] = value as u8;
166            }
167            "bool" | "b" => {
168                self.data[offset] = if value != 0.0 { 1u8 } else { 0u8 };
169            }
170            _ => {}
171        }
172        Ok(())
173    }
174}
175
176/// Resolve the byte size of a dtype name string.
177///
178/// Returns `PyValueError` for unsupported dtype strings.
179fn resolve_itemsize(dtype_name: &str) -> PyResult<usize> {
180    match dtype_name {
181        "float32" | "f32" => Ok(4),
182        "float64" | "f64" => Ok(8),
183        "int32" | "i32" => Ok(4),
184        "int64" | "i64" => Ok(8),
185        "bool" | "b" => Ok(1),
186        "uint8" | "u8" => Ok(1),
187        "int8" | "i8" => Ok(1),
188        _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
189            "unsupported dtype '{dtype_name}'; supported: float32, f32, float64, f64, \
190             int32, i32, int64, i64, bool, b, uint8, u8, int8, i8"
191        ))),
192    }
193}
194
195/// Register the untyped array class into a PyO3 module.
196///
197/// Call this from your `#[pymodule]` init function to expose [`UntypedArray`].
198pub fn register_untyped_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
199    m.add_class::<UntypedArray>()?;
200    Ok(())
201}