Skip to main content

scirs2/
io.rs

//! Python bindings for scirs2-io
//!
//! This module provides Python bindings for file I/O operations:
//! CSV, Matrix Market (sparse and dense), array serialization, and WAV audio files.
5
6use pyo3::exceptions::{PyRuntimeError, PyValueError};
7use pyo3::prelude::*;
8use pyo3::types::{PyAny, PyDict};
9
10// NumPy types for Python array interface
11use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2, PyArrayMethods};
12
13// ndarray types from scirs2-core
14#[allow(unused_imports)]
15use scirs2_core::ndarray::{Array1, Array2};
16
17// Direct imports from scirs2-io
18use scirs2_io::{
19    // CSV operations
20    csv::{read_csv_numeric, write_csv, CsvReaderConfig, CsvWriterConfig},
21    // Matrix Market format
22    matrix_market::{
23        read_dense_matrix, read_sparse_matrix, write_dense_matrix, write_sparse_matrix, MMDataType,
24        MMDenseMatrix, MMFormat, MMHeader, MMSparseMatrix, MMSymmetry, SparseEntry,
25    },
26    // Serialization
27    serialize::{deserialize_array, serialize_array, SerializationFormat},
28    // WAV files
29    wavfile::{read_wav, write_wav},
30};
31
32// ========================================
33// CSV OPERATIONS
34// ========================================
35
36/// Read CSV file into array
37#[pyfunction]
38#[pyo3(signature = (path, has_header=true, delimiter=","))]
39fn read_csv_py(py: Python, path: &str, has_header: bool, delimiter: &str) -> PyResult<Py<PyAny>> {
40    let config = CsvReaderConfig {
41        has_header,
42        delimiter: delimiter.chars().next().unwrap_or(','),
43        ..Default::default()
44    };
45
46    let (headers, data) = read_csv_numeric(path, Some(config))
47        .map_err(|e| PyRuntimeError::new_err(format!("Failed to read CSV: {}", e)))?;
48
49    let dict = PyDict::new(py);
50    dict.set_item("data", data.into_pyarray(py).unbind())?;
51    dict.set_item("headers", headers)?;
52
53    Ok(dict.into())
54}
55
56/// Write array to CSV file
57#[pyfunction]
58#[pyo3(signature = (path, data, headers=None))]
59fn write_csv_py(
60    path: &str,
61    data: &Bound<'_, PyArray2<f64>>,
62    headers: Option<Vec<String>>,
63) -> PyResult<()> {
64    let binding = data.readonly();
65    let arr = binding.as_array();
66    let arr_owned = arr.to_owned();
67
68    let default_headers: Vec<String> = match &headers {
69        Some(h) => h.clone(),
70        None => (0..arr.ncols()).map(|i| format!("col_{}", i)).collect(),
71    };
72
73    write_csv(
74        path,
75        &arr_owned,
76        Some(&default_headers),
77        None::<CsvWriterConfig>,
78    )
79    .map_err(|e| PyRuntimeError::new_err(format!("Failed to write CSV: {}", e)))?;
80
81    Ok(())
82}
83
84// ========================================
85// MATRIX MARKET FORMAT
86// ========================================
87
88/// Read Matrix Market sparse matrix
89#[pyfunction]
90fn read_matrix_market_sparse_py(py: Python, path: &str) -> PyResult<Py<PyAny>> {
91    let matrix = read_sparse_matrix(path)
92        .map_err(|e| PyRuntimeError::new_err(format!("Failed to read sparse matrix: {}", e)))?;
93
94    let dict = PyDict::new(py);
95    dict.set_item("shape", (matrix.rows, matrix.cols))?;
96    dict.set_item("nnz", matrix.nnz)?;
97
98    // Extract COO format from SparseEntry structs
99    let rows: Vec<usize> = matrix.entries.iter().map(|e| e.row).collect();
100    let cols: Vec<usize> = matrix.entries.iter().map(|e| e.col).collect();
101    let data: Vec<f64> = matrix.entries.iter().map(|e| e.value).collect();
102
103    dict.set_item("row", rows)?;
104    dict.set_item("col", cols)?;
105    dict.set_item("data", data)?;
106
107    Ok(dict.into())
108}
109
110/// Write sparse matrix in Matrix Market format
111#[pyfunction]
112fn write_matrix_market_sparse_py(
113    path: &str,
114    rows: Vec<usize>,
115    cols: Vec<usize>,
116    data: Vec<f64>,
117    shape: (usize, usize),
118) -> PyResult<()> {
119    let entries: Vec<SparseEntry<f64>> = rows
120        .into_iter()
121        .zip(cols)
122        .zip(data)
123        .map(|((r, c), v)| SparseEntry {
124            row: r,
125            col: c,
126            value: v,
127        })
128        .collect();
129
130    let header = MMHeader {
131        object: "matrix".to_string(),
132        format: MMFormat::Coordinate,
133        data_type: MMDataType::Real,
134        symmetry: MMSymmetry::General,
135        comments: Vec::new(),
136    };
137
138    let matrix = MMSparseMatrix {
139        header,
140        rows: shape.0,
141        cols: shape.1,
142        nnz: entries.len(),
143        entries,
144    };
145
146    write_sparse_matrix(path, &matrix)
147        .map_err(|e| PyRuntimeError::new_err(format!("Failed to write sparse matrix: {}", e)))?;
148
149    Ok(())
150}
151
152/// Read Matrix Market dense matrix
153#[pyfunction]
154fn read_matrix_market_dense_py(py: Python, path: &str) -> PyResult<Py<PyArray2<f64>>> {
155    let matrix = read_dense_matrix(path)
156        .map_err(|e| PyRuntimeError::new_err(format!("Failed to read dense matrix: {}", e)))?;
157
158    Ok(matrix.data.into_pyarray(py).unbind())
159}
160
161/// Write dense matrix in Matrix Market format
162#[pyfunction]
163fn write_matrix_market_dense_py(path: &str, data: &Bound<'_, PyArray2<f64>>) -> PyResult<()> {
164    let binding = data.readonly();
165    let arr = binding.as_array();
166
167    let header = MMHeader {
168        object: "matrix".to_string(),
169        format: MMFormat::Array,
170        data_type: MMDataType::Real,
171        symmetry: MMSymmetry::General,
172        comments: Vec::new(),
173    };
174
175    let matrix = MMDenseMatrix {
176        header,
177        rows: arr.nrows(),
178        cols: arr.ncols(),
179        data: arr.to_owned(),
180    };
181
182    write_dense_matrix(path, &matrix)
183        .map_err(|e| PyRuntimeError::new_err(format!("Failed to write dense matrix: {}", e)))?;
184
185    Ok(())
186}
187
188// ========================================
189// SERIALIZATION
190// ========================================
191
192/// Save array to binary file
193#[pyfunction]
194#[pyo3(signature = (path, data, format="binary"))]
195fn save_array_py(path: &str, data: &Bound<'_, PyArray2<f64>>, format: &str) -> PyResult<()> {
196    let binding = data.readonly();
197    let arr = binding.as_array();
198
199    let ser_format = match format {
200        "binary" => SerializationFormat::Binary,
201        "json" => SerializationFormat::JSON,
202        "messagepack" | "msgpack" => SerializationFormat::MessagePack,
203        _ => return Err(PyValueError::new_err(format!("Unknown format: {}", format))),
204    };
205
206    // Convert to dynamic dimensionality for serialize_array
207    let arr_dyn = arr.to_owned().into_dyn();
208    serialize_array::<_, f64, _>(path, &arr_dyn, ser_format)
209        .map_err(|e| PyRuntimeError::new_err(format!("Failed to save array: {}", e)))?;
210
211    Ok(())
212}
213
214/// Load array from binary file
215#[pyfunction]
216#[pyo3(signature = (path, format="binary"))]
217fn load_array_py(py: Python, path: &str, format: &str) -> PyResult<Py<PyArray2<f64>>> {
218    let ser_format = match format {
219        "binary" => SerializationFormat::Binary,
220        "json" => SerializationFormat::JSON,
221        "messagepack" | "msgpack" => SerializationFormat::MessagePack,
222        _ => return Err(PyValueError::new_err(format!("Unknown format: {}", format))),
223    };
224
225    let arr: scirs2_core::ndarray::ArrayD<f64> = deserialize_array(path, ser_format)
226        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load array: {}", e)))?;
227
228    // Convert to 2D
229    let shape = arr.shape();
230    if shape.len() == 2 {
231        let arr2d = arr
232            .into_dimensionality::<scirs2_core::ndarray::Ix2>()
233            .map_err(|e| PyRuntimeError::new_err(format!("Shape conversion failed: {}", e)))?;
234        Ok(arr2d.into_pyarray(py).unbind())
235    } else {
236        Err(PyValueError::new_err("Array is not 2-dimensional"))
237    }
238}
239
240// ========================================
241// WAV FILE OPERATIONS
242// ========================================
243
244/// Read WAV audio file
245#[pyfunction]
246fn read_wav_py(py: Python, path: &str) -> PyResult<Py<PyAny>> {
247    let (header, data) = read_wav(path)
248        .map_err(|e| PyRuntimeError::new_err(format!("Failed to read WAV: {}", e)))?;
249
250    let dict = PyDict::new(py);
251    dict.set_item("sample_rate", header.sample_rate)?;
252    dict.set_item("channels", header.channels)?;
253    dict.set_item("bits_per_sample", header.bits_per_sample)?;
254
255    // Convert to f64 for consistency
256    let data_f64: Array1<f64> = data.mapv(|x| x as f64).iter().cloned().collect();
257    dict.set_item("data", data_f64.into_pyarray(py).unbind())?;
258
259    Ok(dict.into())
260}
261
262/// Write WAV audio file
263#[pyfunction]
264fn write_wav_py(path: &str, samplerate: u32, data: &Bound<'_, PyArray1<f64>>) -> PyResult<()> {
265    let binding = data.readonly();
266    let arr = binding.as_array();
267
268    // Convert to f32
269    let data_f32: scirs2_core::ndarray::ArrayD<f32> = arr.mapv(|x| x as f32).into_dyn();
270
271    write_wav(path, samplerate, &data_f32)
272        .map_err(|e| PyRuntimeError::new_err(format!("Failed to write WAV: {}", e)))?;
273
274    Ok(())
275}
276
277/// Python module registration
278pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
279    // CSV operations
280    m.add_function(wrap_pyfunction!(read_csv_py, m)?)?;
281    m.add_function(wrap_pyfunction!(write_csv_py, m)?)?;
282
283    // Matrix Market format
284    m.add_function(wrap_pyfunction!(read_matrix_market_sparse_py, m)?)?;
285    m.add_function(wrap_pyfunction!(write_matrix_market_sparse_py, m)?)?;
286    m.add_function(wrap_pyfunction!(read_matrix_market_dense_py, m)?)?;
287    m.add_function(wrap_pyfunction!(write_matrix_market_dense_py, m)?)?;
288
289    // Serialization
290    m.add_function(wrap_pyfunction!(save_array_py, m)?)?;
291    m.add_function(wrap_pyfunction!(load_array_py, m)?)?;
292
293    // WAV files
294    m.add_function(wrap_pyfunction!(read_wav_py, m)?)?;
295    m.add_function(wrap_pyfunction!(write_wav_py, m)?)?;
296
297    Ok(())
298}