1use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13#[allow(unused_imports)]
15use scirs2_core::ndarray::{Array1, Array2};
16
17use scirs2_io::{
19 csv::{read_csv_numeric, write_csv, CsvReaderConfig, CsvWriterConfig},
21 matrix_market::{
23 read_dense_matrix, read_sparse_matrix, write_dense_matrix, write_sparse_matrix, MMDataType,
24 MMDenseMatrix, MMFormat, MMHeader, MMSparseMatrix, MMSymmetry, SparseEntry,
25 },
26 serialize::{deserialize_array, serialize_array, SerializationFormat},
28 wavfile::{read_wav, write_wav},
30};
31
32#[pyfunction]
38#[pyo3(signature = (path, has_header=true, delimiter=","))]
39fn read_csv_py(py: Python, path: &str, has_header: bool, delimiter: &str) -> PyResult<Py<PyAny>> {
40 let config = CsvReaderConfig {
41 has_header,
42 delimiter: delimiter.chars().next().unwrap_or(','),
43 ..Default::default()
44 };
45
46 let (headers, data) = read_csv_numeric(path, Some(config))
47 .map_err(|e| PyRuntimeError::new_err(format!("Failed to read CSV: {}", e)))?;
48
49 let dict = PyDict::new(py);
50 dict.set_item("data", data.into_pyarray(py).unbind())?;
51 dict.set_item("headers", headers)?;
52
53 Ok(dict.into())
54}
55
56#[pyfunction]
58#[pyo3(signature = (path, data, headers=None))]
59fn write_csv_py(
60 path: &str,
61 data: &Bound<'_, PyArray2<f64>>,
62 headers: Option<Vec<String>>,
63) -> PyResult<()> {
64 let binding = data.readonly();
65 let arr = binding.as_array();
66 let arr_owned = arr.to_owned();
67
68 let default_headers: Vec<String> = match &headers {
69 Some(h) => h.clone(),
70 None => (0..arr.ncols()).map(|i| format!("col_{}", i)).collect(),
71 };
72
73 write_csv(
74 path,
75 &arr_owned,
76 Some(&default_headers),
77 None::<CsvWriterConfig>,
78 )
79 .map_err(|e| PyRuntimeError::new_err(format!("Failed to write CSV: {}", e)))?;
80
81 Ok(())
82}
83
84#[pyfunction]
90fn read_matrix_market_sparse_py(py: Python, path: &str) -> PyResult<Py<PyAny>> {
91 let matrix = read_sparse_matrix(path)
92 .map_err(|e| PyRuntimeError::new_err(format!("Failed to read sparse matrix: {}", e)))?;
93
94 let dict = PyDict::new(py);
95 dict.set_item("shape", (matrix.rows, matrix.cols))?;
96 dict.set_item("nnz", matrix.nnz)?;
97
98 let rows: Vec<usize> = matrix.entries.iter().map(|e| e.row).collect();
100 let cols: Vec<usize> = matrix.entries.iter().map(|e| e.col).collect();
101 let data: Vec<f64> = matrix.entries.iter().map(|e| e.value).collect();
102
103 dict.set_item("row", rows)?;
104 dict.set_item("col", cols)?;
105 dict.set_item("data", data)?;
106
107 Ok(dict.into())
108}
109
110#[pyfunction]
112fn write_matrix_market_sparse_py(
113 path: &str,
114 rows: Vec<usize>,
115 cols: Vec<usize>,
116 data: Vec<f64>,
117 shape: (usize, usize),
118) -> PyResult<()> {
119 let entries: Vec<SparseEntry<f64>> = rows
120 .into_iter()
121 .zip(cols)
122 .zip(data)
123 .map(|((r, c), v)| SparseEntry {
124 row: r,
125 col: c,
126 value: v,
127 })
128 .collect();
129
130 let header = MMHeader {
131 object: "matrix".to_string(),
132 format: MMFormat::Coordinate,
133 data_type: MMDataType::Real,
134 symmetry: MMSymmetry::General,
135 comments: Vec::new(),
136 };
137
138 let matrix = MMSparseMatrix {
139 header,
140 rows: shape.0,
141 cols: shape.1,
142 nnz: entries.len(),
143 entries,
144 };
145
146 write_sparse_matrix(path, &matrix)
147 .map_err(|e| PyRuntimeError::new_err(format!("Failed to write sparse matrix: {}", e)))?;
148
149 Ok(())
150}
151
152#[pyfunction]
154fn read_matrix_market_dense_py(py: Python, path: &str) -> PyResult<Py<PyArray2<f64>>> {
155 let matrix = read_dense_matrix(path)
156 .map_err(|e| PyRuntimeError::new_err(format!("Failed to read dense matrix: {}", e)))?;
157
158 Ok(matrix.data.into_pyarray(py).unbind())
159}
160
161#[pyfunction]
163fn write_matrix_market_dense_py(path: &str, data: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
164 let binding = data.readonly();
165 let arr = binding.as_array();
166
167 let header = MMHeader {
168 object: "matrix".to_string(),
169 format: MMFormat::Array,
170 data_type: MMDataType::Real,
171 symmetry: MMSymmetry::General,
172 comments: Vec::new(),
173 };
174
175 let matrix = MMDenseMatrix {
176 header,
177 rows: arr.nrows(),
178 cols: arr.ncols(),
179 data: arr.to_owned(),
180 };
181
182 write_dense_matrix(path, &matrix)
183 .map_err(|e| PyRuntimeError::new_err(format!("Failed to write dense matrix: {}", e)))?;
184
185 Ok(())
186}
187
188#[pyfunction]
194#[pyo3(signature = (path, data, format="binary"))]
195fn save_array_py(path: &str, data: &Bound<'_, PyArray2<f64>>, format: &str) -> PyResult<()> {
196 let binding = data.readonly();
197 let arr = binding.as_array();
198
199 let ser_format = match format {
200 "binary" => SerializationFormat::Binary,
201 "json" => SerializationFormat::JSON,
202 "messagepack" | "msgpack" => SerializationFormat::MessagePack,
203 _ => return Err(PyValueError::new_err(format!("Unknown format: {}", format))),
204 };
205
206 let arr_dyn = arr.to_owned().into_dyn();
208 serialize_array::<_, f64, _>(path, &arr_dyn, ser_format)
209 .map_err(|e| PyRuntimeError::new_err(format!("Failed to save array: {}", e)))?;
210
211 Ok(())
212}
213
214#[pyfunction]
216#[pyo3(signature = (path, format="binary"))]
217fn load_array_py(py: Python, path: &str, format: &str) -> PyResult<Py<PyArray2<f64>>> {
218 let ser_format = match format {
219 "binary" => SerializationFormat::Binary,
220 "json" => SerializationFormat::JSON,
221 "messagepack" | "msgpack" => SerializationFormat::MessagePack,
222 _ => return Err(PyValueError::new_err(format!("Unknown format: {}", format))),
223 };
224
225 let arr: scirs2_core::ndarray::ArrayD<f64> = deserialize_array(path, ser_format)
226 .map_err(|e| PyRuntimeError::new_err(format!("Failed to load array: {}", e)))?;
227
228 let shape = arr.shape();
230 if shape.len() == 2 {
231 let arr2d = arr
232 .into_dimensionality::<scirs2_core::ndarray::Ix2>()
233 .map_err(|e| PyRuntimeError::new_err(format!("Shape conversion failed: {}", e)))?;
234 Ok(arr2d.into_pyarray(py).unbind())
235 } else {
236 Err(PyValueError::new_err("Array is not 2-dimensional"))
237 }
238}
239
240#[pyfunction]
246fn read_wav_py(py: Python, path: &str) -> PyResult<Py<PyAny>> {
247 let (header, data) = read_wav(path)
248 .map_err(|e| PyRuntimeError::new_err(format!("Failed to read WAV: {}", e)))?;
249
250 let dict = PyDict::new(py);
251 dict.set_item("sample_rate", header.sample_rate)?;
252 dict.set_item("channels", header.channels)?;
253 dict.set_item("bits_per_sample", header.bits_per_sample)?;
254
255 let data_f64: Array1<f64> = data.mapv(|x| x as f64).iter().cloned().collect();
257 dict.set_item("data", data_f64.into_pyarray(py).unbind())?;
258
259 Ok(dict.into())
260}
261
262#[pyfunction]
264fn write_wav_py(path: &str, samplerate: u32, data: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
265 let binding = data.readonly();
266 let arr = binding.as_array();
267
268 let data_f32: scirs2_core::ndarray::ArrayD<f32> = arr.mapv(|x| x as f32).into_dyn();
270
271 write_wav(path, samplerate, &data_f32)
272 .map_err(|e| PyRuntimeError::new_err(format!("Failed to write WAV: {}", e)))?;
273
274 Ok(())
275}
276
277pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
279 m.add_function(wrap_pyfunction!(read_csv_py, m)?)?;
281 m.add_function(wrap_pyfunction!(write_csv_py, m)?)?;
282
283 m.add_function(wrap_pyfunction!(read_matrix_market_sparse_py, m)?)?;
285 m.add_function(wrap_pyfunction!(write_matrix_market_sparse_py, m)?)?;
286 m.add_function(wrap_pyfunction!(read_matrix_market_dense_py, m)?)?;
287 m.add_function(wrap_pyfunction!(write_matrix_market_dense_py, m)?)?;
288
289 m.add_function(wrap_pyfunction!(save_array_py, m)?)?;
291 m.add_function(wrap_pyfunction!(load_array_py, m)?)?;
292
293 m.add_function(wrap_pyfunction!(read_wav_py, m)?)?;
295 m.add_function(wrap_pyfunction!(write_wav_py, m)?)?;
296
297 Ok(())
298}