Skip to main content

scirs2_numpy/
structured.rs

1//! Structured dtype support — record arrays with named, typed fields.
2//!
3//! Mirrors `numpy.dtype([("name", "f64"), ...])` and `numpy.recarray`.
4//! A [`StructuredDtype`] describes the layout of a single record, and a
5//! [`StructuredArray`] stores `n` such records as a flat byte buffer.
6
7use pyo3::prelude::*;
8
/// Descriptor for a single named field within a structured dtype.
///
/// Built by [`StructuredDtype::new`] in this module. From Python all three
/// attributes are read-only (`#[pyo3(get)]` without `set`), and no `#[new]`
/// constructor is defined here.
// NOTE(review): `from_py_object` presumably opts this `Clone` pyclass into
// by-value extraction from Python objects — confirm against the pyo3 version in use.
#[derive(Debug, Clone)]
#[pyclass(name = "DtypeField", from_py_object)]
pub struct DtypeField {
    /// Field name as supplied in the dtype specification.
    #[pyo3(get)]
    pub name: String,
    /// Element type string, e.g. `"f64"`, `"i32"`, `"bool"`.
    ///
    /// Stored verbatim as passed by the caller (so it may be a long alias such
    /// as `"float64"`); it is not normalised.
    #[pyo3(get)]
    pub dtype: String,
    /// Byte offset of this field within one record.
    ///
    /// Fields are packed with no padding, so this equals the summed sizes of
    /// all preceding fields.
    #[pyo3(get)]
    pub offset: usize,
}
23
24/// A structured dtype: a record type composed of multiple named fields.
25///
26/// Mirrors `numpy.dtype` with compound (structured) field specifications.
27#[pyclass(name = "StructuredDtype")]
28pub struct StructuredDtype {
29    /// Ordered list of field descriptors.
30    fields: Vec<DtypeField>,
31    /// Total byte size of one record.
32    itemsize: usize,
33}
34
35#[pymethods]
36impl StructuredDtype {
37    /// Build a structured dtype from a list of `(name, dtype_str)` pairs.
38    ///
39    /// Fields are laid out sequentially with no padding.  The first field is at
40    /// byte offset 0; each subsequent field starts immediately after the previous.
41    #[new]
42    pub fn new(field_specs: Vec<(String, String)>) -> PyResult<Self> {
43        let mut offset = 0usize;
44        let mut fields = Vec::with_capacity(field_specs.len());
45        for (name, dtype) in field_specs {
46            let size = dtype_size(&dtype)?;
47            fields.push(DtypeField {
48                name,
49                dtype,
50                offset,
51            });
52            offset += size;
53        }
54        Ok(Self {
55            fields,
56            itemsize: offset,
57        })
58    }
59
60    /// Return an ordered list of field names.
61    pub fn names(&self) -> Vec<String> {
62        self.fields.iter().map(|f| f.name.clone()).collect()
63    }
64
65    /// Return the total byte size of one record.
66    pub fn itemsize(&self) -> usize {
67        self.itemsize
68    }
69
70    /// Return the byte offset of each field (in field order).
71    pub fn offsets(&self) -> Vec<usize> {
72        self.fields.iter().map(|f| f.offset).collect()
73    }
74
75    /// Return the number of fields.
76    pub fn field_count(&self) -> usize {
77        self.fields.len()
78    }
79}
80
81/// Return the byte size for a dtype name string.
82///
83/// Returns a `PyValueError` for unrecognised dtype strings.
84fn dtype_size(dtype: &str) -> PyResult<usize> {
85    match dtype {
86        "f32" | "float32" => Ok(4),
87        "f64" | "float64" => Ok(8),
88        "i32" | "int32" => Ok(4),
89        "i64" | "int64" => Ok(8),
90        "u32" | "uint32" => Ok(4),
91        "u64" | "uint64" => Ok(8),
92        "bool" => Ok(1),
93        "i8" | "int8" => Ok(1),
94        "u8" | "uint8" => Ok(1),
95        _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
96            "unknown dtype '{dtype}'; supported: f32, f64, i32, i64, u32, u64, bool, i8, u8"
97        ))),
98    }
99}
100
/// A structured array: a flat byte buffer interpreted as `n_records` records,
/// each described by a [`StructuredDtype`].
///
/// Records are stored contiguously: record `i` occupies bytes
/// `i * itemsize .. (i + 1) * itemsize` of `data`, and a field lives at its
/// `offset` within that range. The `*_field_f64` methods encode values
/// little-endian.
#[pyclass(name = "StructuredArray")]
pub struct StructuredArray {
    /// Dtype describing the layout of each record.
    dtype: StructuredDtype,
    /// Raw backing buffer; length is `n_records * dtype.itemsize`.
    data: Vec<u8>,
    /// Number of records stored (kept separately because `itemsize` may be 0,
    /// in which case it cannot be recovered from `data.len()`).
    n_records: usize,
}
112
113#[pymethods]
114impl StructuredArray {
115    /// Construct a zero-initialised structured array.
116    ///
117    /// # Arguments
118    /// * `n_records`   – number of records to allocate.
119    /// * `field_specs` – list of `(name, dtype_str)` pairs defining the record layout.
120    #[new]
121    pub fn new_empty(n_records: usize, field_specs: Vec<(String, String)>) -> PyResult<Self> {
122        let dtype = StructuredDtype::new(field_specs)?;
123        let data = vec![0u8; n_records * dtype.itemsize];
124        Ok(Self {
125            dtype,
126            data,
127            n_records,
128        })
129    }
130
131    /// Return the number of records.
132    pub fn n_records(&self) -> usize {
133        self.n_records
134    }
135
136    /// Return the byte size of one record.
137    pub fn itemsize(&self) -> usize {
138        self.dtype.itemsize
139    }
140
141    /// Read all values for an `f64` field as a `Vec<f64>`.
142    ///
143    /// Returns `PyKeyError` if `field_name` is not found, or `PyTypeError`
144    /// if the field dtype is not `"f64"` / `"float64"`.
145    pub fn get_field_f64(&self, field_name: &str) -> PyResult<Vec<f64>> {
146        let field = self
147            .dtype
148            .fields
149            .iter()
150            .find(|f| f.name == field_name)
151            .ok_or_else(|| {
152                pyo3::exceptions::PyKeyError::new_err(format!("field '{field_name}' not found"))
153            })?;
154        if field.dtype != "f64" && field.dtype != "float64" {
155            return Err(pyo3::exceptions::PyTypeError::new_err(format!(
156                "field '{field_name}' has dtype '{}', not f64",
157                field.dtype
158            )));
159        }
160        let mut result = Vec::with_capacity(self.n_records);
161        for i in 0..self.n_records {
162            let byte_offset = i * self.dtype.itemsize + field.offset;
163            let bytes: [u8; 8] = self.data[byte_offset..byte_offset + 8]
164                .try_into()
165                .map_err(|_| pyo3::exceptions::PyValueError::new_err("slice conversion error"))?;
166            result.push(f64::from_le_bytes(bytes));
167        }
168        Ok(result)
169    }
170
171    /// Write all values for an `f64` field from a `Vec<f64>`.
172    ///
173    /// Returns `PyValueError` if `values.len() != n_records`, `PyKeyError` if the
174    /// field is not found, or `PyTypeError` if the field dtype is not `"f64"` / `"float64"`.
175    pub fn set_field_f64(&mut self, field_name: &str, values: Vec<f64>) -> PyResult<()> {
176        if values.len() != self.n_records {
177            return Err(pyo3::exceptions::PyValueError::new_err(format!(
178                "values length {} does not match n_records {}",
179                values.len(),
180                self.n_records
181            )));
182        }
183        let field = self
184            .dtype
185            .fields
186            .iter()
187            .find(|f| f.name == field_name)
188            .ok_or_else(|| {
189                pyo3::exceptions::PyKeyError::new_err(format!("field '{field_name}' not found"))
190            })?;
191        if field.dtype != "f64" && field.dtype != "float64" {
192            return Err(pyo3::exceptions::PyTypeError::new_err(format!(
193                "field '{field_name}' has dtype '{}', not f64",
194                field.dtype
195            )));
196        }
197        let field_offset = field.offset;
198        let itemsize = self.dtype.itemsize;
199        for (i, &v) in values.iter().enumerate() {
200            let byte_offset = i * itemsize + field_offset;
201            self.data[byte_offset..byte_offset + 8].copy_from_slice(&v.to_le_bytes());
202        }
203        Ok(())
204    }
205
206    /// Return the names of all fields.
207    pub fn field_names(&self) -> Vec<String> {
208        self.dtype.names()
209    }
210}
211
212/// Register structured dtype classes into a PyO3 module.
213///
214/// Call this from your `#[pymodule]` init function to expose `DtypeField`,
215/// `StructuredDtype`, and `StructuredArray`.
216pub fn register_structured_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
217    m.add_class::<DtypeField>()?;
218    m.add_class::<StructuredDtype>()?;
219    m.add_class::<StructuredArray>()?;
220    Ok(())
221}