1use pyo3::prelude::*;
12
13#[pyclass(name = "UntypedArray")]
18pub struct UntypedArray {
19 data: Vec<u8>,
21 dtype_name: String,
23 shape: Vec<usize>,
25 itemsize: usize,
27}
28
29#[pymethods]
30impl UntypedArray {
31 #[new]
39 pub fn new(shape: Vec<usize>, dtype_name: String) -> PyResult<Self> {
40 let itemsize = resolve_itemsize(&dtype_name)?;
41 let n: usize = shape.iter().product::<usize>().max(1);
42 Ok(Self {
43 data: vec![0u8; n * itemsize],
44 dtype_name,
45 shape,
46 itemsize,
47 })
48 }
49
50 pub fn dtype_name(&self) -> &str {
52 &self.dtype_name
53 }
54
55 pub fn itemsize(&self) -> usize {
57 self.itemsize
58 }
59
60 pub fn shape(&self) -> Vec<usize> {
62 self.shape.clone()
63 }
64
65 pub fn ndim(&self) -> usize {
67 self.shape.len()
68 }
69
70 pub fn nbytes(&self) -> usize {
72 self.data.len()
73 }
74
75 pub fn size(&self) -> usize {
77 self.shape.iter().product()
78 }
79
80 pub fn is_floating(&self) -> bool {
82 matches!(
83 self.dtype_name.as_str(),
84 "float32" | "f32" | "float64" | "f64"
85 )
86 }
87
88 pub fn is_integer(&self) -> bool {
90 matches!(
91 self.dtype_name.as_str(),
92 "int32" | "i32" | "int64" | "i64" | "int8" | "i8" | "uint8" | "u8"
93 )
94 }
95
96 pub fn read_as_f64(&self, flat_index: usize) -> PyResult<f64> {
100 let offset = flat_index * self.itemsize;
101 if offset + self.itemsize > self.data.len() {
102 return Err(pyo3::exceptions::PyIndexError::new_err(format!(
103 "flat_index {flat_index} is out of bounds"
104 )));
105 }
106 let value = match self.dtype_name.as_str() {
107 "float32" | "f32" => {
108 let bytes: [u8; 4] = self.data[offset..offset + 4].try_into().map_err(|_| {
109 pyo3::exceptions::PyValueError::new_err("slice conversion error (f32)")
110 })?;
111 f32::from_le_bytes(bytes) as f64
112 }
113 "float64" | "f64" => {
114 let bytes: [u8; 8] = self.data[offset..offset + 8].try_into().map_err(|_| {
115 pyo3::exceptions::PyValueError::new_err("slice conversion error (f64)")
116 })?;
117 f64::from_le_bytes(bytes)
118 }
119 "int32" | "i32" => {
120 let bytes: [u8; 4] = self.data[offset..offset + 4].try_into().map_err(|_| {
121 pyo3::exceptions::PyValueError::new_err("slice conversion error (i32)")
122 })?;
123 i32::from_le_bytes(bytes) as f64
124 }
125 "int64" | "i64" => {
126 let bytes: [u8; 8] = self.data[offset..offset + 8].try_into().map_err(|_| {
127 pyo3::exceptions::PyValueError::new_err("slice conversion error (i64)")
128 })?;
129 i64::from_le_bytes(bytes) as f64
130 }
131 "int8" | "i8" => self.data[offset] as i8 as f64,
132 "uint8" | "u8" | "bool" | "b" => self.data[offset] as f64,
133 _ => 0.0,
134 };
135 Ok(value)
136 }
137
138 pub fn write_f64(&mut self, flat_index: usize, value: f64) -> PyResult<()> {
142 let offset = flat_index * self.itemsize;
143 if offset + self.itemsize > self.data.len() {
144 return Err(pyo3::exceptions::PyIndexError::new_err(format!(
145 "flat_index {flat_index} is out of bounds"
146 )));
147 }
148 match self.dtype_name.as_str() {
149 "float32" | "f32" => {
150 self.data[offset..offset + 4].copy_from_slice(&(value as f32).to_le_bytes());
151 }
152 "float64" | "f64" => {
153 self.data[offset..offset + 8].copy_from_slice(&value.to_le_bytes());
154 }
155 "int32" | "i32" => {
156 self.data[offset..offset + 4].copy_from_slice(&(value as i32).to_le_bytes());
157 }
158 "int64" | "i64" => {
159 self.data[offset..offset + 8].copy_from_slice(&(value as i64).to_le_bytes());
160 }
161 "int8" | "i8" => {
162 self.data[offset] = value as i8 as u8;
163 }
164 "uint8" | "u8" => {
165 self.data[offset] = value as u8;
166 }
167 "bool" | "b" => {
168 self.data[offset] = if value != 0.0 { 1u8 } else { 0u8 };
169 }
170 _ => {}
171 }
172 Ok(())
173 }
174}
175
176fn resolve_itemsize(dtype_name: &str) -> PyResult<usize> {
180 match dtype_name {
181 "float32" | "f32" => Ok(4),
182 "float64" | "f64" => Ok(8),
183 "int32" | "i32" => Ok(4),
184 "int64" | "i64" => Ok(8),
185 "bool" | "b" => Ok(1),
186 "uint8" | "u8" => Ok(1),
187 "int8" | "i8" => Ok(1),
188 _ => Err(pyo3::exceptions::PyValueError::new_err(format!(
189 "unsupported dtype '{dtype_name}'; supported: float32, f32, float64, f64, \
190 int32, i32, int64, i64, bool, b, uint8, u8, int8, i8"
191 ))),
192 }
193}
194
195pub fn register_untyped_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
199 m.add_class::<UntypedArray>()?;
200 Ok(())
201}