1use ::serde::{Deserialize, Serialize};
15use scirs2_core::ndarray::{Array, Array2, ArrayBase, IxDyn};
16use std::collections::HashMap;
17use std::fs::File;
18use std::io::{BufReader, BufWriter, Write};
19use std::path::Path;
20
21use crate::error::{IoError, Result};
22use bincode::{config, serde as bincode_serde};
23fn bincode_cfg() -> impl bincode::config::Config {
24 config::standard()
25}
26
/// On-disk encodings supported by the serialization helpers in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SerializationFormat {
    /// Compact binary encoding produced with `bincode`.
    Binary,
    /// Pretty-printed JSON produced with `serde_json`.
    JSON,
    /// MessagePack encoding produced with `rmp_serde`.
    MessagePack,
}
37
38#[allow(dead_code)]
66pub fn serialize_array<P, A, S>(
67 path: P,
68 array: &ArrayBase<S, IxDyn>,
69 format: SerializationFormat,
70) -> Result<()>
71where
72 P: AsRef<Path>,
73 A: Serialize + Clone,
74 S: scirs2_core::ndarray::Data<Elem = A>,
75{
76 let shape = array.shape().to_vec();
78 let data: Vec<A> = array.iter().cloned().collect();
79
80 let serializable = SerializedArray {
82 metadata: ArrayMetadata {
83 shape,
84 dtype: std::any::type_name::<A>().to_string(),
85 order: 'C',
86 metadata: std::collections::HashMap::new(),
87 },
88 data,
89 };
90
91 let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
92 let mut writer = BufWriter::new(file);
93
94 match format {
95 SerializationFormat::Binary => {
96 let cfg = bincode_cfg();
100 let bytes = bincode_serde::encode_to_vec(&serializable, cfg)
101 .map_err(|e| IoError::SerializationError(e.to_string()))?;
102 let len = bytes.len() as u64;
103 writer
104 .write_all(&len.to_le_bytes())
105 .map_err(|e| IoError::FileError(e.to_string()))?;
106 writer
107 .write_all(&bytes)
108 .map_err(|e| IoError::FileError(e.to_string()))?;
109 }
110 SerializationFormat::JSON => {
111 serde_json::to_writer_pretty(&mut writer, &serializable)
112 .map_err(|e| IoError::SerializationError(e.to_string()))?;
113 }
114 SerializationFormat::MessagePack => {
115 rmp_serde::encode::write(&mut writer, &serializable)
116 .map_err(|e| IoError::SerializationError(e.to_string()))?;
117 }
118 }
119
120 writer
121 .flush()
122 .map_err(|e| IoError::FileError(e.to_string()))?;
123 Ok(())
124}
125
126#[allow(dead_code)]
151pub fn deserialize_array<P, A>(path: P, format: SerializationFormat) -> Result<Array<A, IxDyn>>
152where
153 P: AsRef<Path>,
154 A: for<'de> Deserialize<'de> + Clone,
155{
156 let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
157 let mut reader = BufReader::new(file);
158
159 let serialized: SerializedArray<A> = match format {
160 SerializationFormat::Binary => {
161 use std::io::Read;
164 let mut buf = Vec::new();
165 reader
166 .read_to_end(&mut buf)
167 .map_err(|e| IoError::FileError(e.to_string()))?;
168 if buf.len() >= 8 {
169 let mut len_bytes = [0u8; 8];
170 len_bytes.copy_from_slice(&buf[0..8]);
171 let declared = u64::from_le_bytes(len_bytes) as usize;
172 if declared <= buf.len() - 8 {
173 let data_slice = &buf[8..8 + declared];
174 let cfg = bincode_cfg();
175 if let Ok((val, _consumed)) =
176 bincode_serde::decode_from_slice::<SerializedArray<A>, _>(data_slice, cfg)
177 {
178 val
179 } else {
180 let cfg = bincode_cfg();
182 let (val, _len): (SerializedArray<A>, usize) =
183 bincode_serde::decode_from_slice(&buf, cfg)
184 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
185 val
186 }
187 } else {
188 let cfg = bincode_cfg();
190 let (val, _len): (SerializedArray<A>, usize) =
191 bincode_serde::decode_from_slice(&buf, cfg)
192 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
193 val
194 }
195 } else {
196 let cfg = bincode_cfg();
198 let (val, _len): (SerializedArray<A>, usize) =
199 bincode_serde::decode_from_slice(&buf, cfg)
200 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
201 val
202 }
203 }
204 SerializationFormat::JSON => serde_json::from_reader(reader)
205 .map_err(|e| IoError::DeserializationError(e.to_string()))?,
206 SerializationFormat::MessagePack => rmp_serde::from_read(reader)
207 .map_err(|e| IoError::DeserializationError(e.to_string()))?,
208 };
209
210 let array = Array::from_shape_vec(IxDyn(&serialized.metadata.shape), serialized.data)
212 .map_err(|e| IoError::FormatError(format!("Failed to reconstruct array: {}", e)))?;
213
214 Ok(array)
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct ArrayMetadata {
220 pub shape: Vec<usize>,
222 pub dtype: String,
224 pub order: char,
226 pub metadata: std::collections::HashMap<String, String>,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct SerializedArray<A> {
233 pub metadata: ArrayMetadata,
235 pub data: Vec<A>,
237}
238
239#[allow(dead_code)]
275pub fn serialize_array_with_metadata<P, A, S>(
276 path: P,
277 array: &ArrayBase<S, IxDyn>,
278 metadata: std::collections::HashMap<String, String>,
279 format: SerializationFormat,
280) -> Result<()>
281where
282 P: AsRef<Path>,
283 A: Serialize + Clone,
284 S: scirs2_core::ndarray::Data<Elem = A>,
285{
286 let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
287 let mut writer = BufWriter::new(file);
288
289 let shape = array.shape().to_vec();
291 let dtype = std::any::type_name::<A>().to_string();
292
293 let array_metadata = ArrayMetadata {
294 shape,
295 dtype,
296 order: 'C', metadata,
298 };
299
300 let serialized = SerializedArray {
302 metadata: array_metadata,
303 data: array.iter().cloned().collect(),
304 };
305
306 match format {
308 SerializationFormat::Binary => {
309 let cfg = bincode_cfg();
310 bincode_serde::encode_into_std_write(&serialized, &mut writer, cfg)
311 .map_err(|e| IoError::SerializationError(e.to_string()))?;
312 }
313 SerializationFormat::JSON => {
314 serde_json::to_writer_pretty(&mut writer, &serialized)
315 .map_err(|e| IoError::SerializationError(e.to_string()))?;
316 }
317 SerializationFormat::MessagePack => {
318 rmp_serde::encode::write(&mut writer, &serialized)
319 .map_err(|e| IoError::SerializationError(e.to_string()))?;
320 }
321 }
322
323 writer
324 .flush()
325 .map_err(|e| IoError::FileError(e.to_string()))?;
326 Ok(())
327}
328
329#[allow(dead_code)]
356pub fn deserialize_array_with_metadata<P, A>(
357 path: P,
358 format: SerializationFormat,
359) -> Result<(Array<A, IxDyn>, std::collections::HashMap<String, String>)>
360where
361 P: AsRef<Path>,
362 A: for<'de> Deserialize<'de> + Clone,
363{
364 let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
365 let mut reader = BufReader::new(file);
366
367 let serialized: SerializedArray<A> = match format {
368 SerializationFormat::Binary => {
369 let cfg = bincode_cfg();
370 let (val, _len): (SerializedArray<A>, usize) =
371 bincode_serde::decode_from_std_read(&mut reader, cfg)
372 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
373 val
374 }
375 SerializationFormat::JSON => serde_json::from_reader(reader)
376 .map_err(|e| IoError::DeserializationError(e.to_string()))?,
377 SerializationFormat::MessagePack => rmp_serde::from_read(reader)
378 .map_err(|e| IoError::DeserializationError(e.to_string()))?,
379 };
380
381 let shape = serialized.metadata.shape;
383 let data = serialized.data;
384
385 let array = Array::from_shape_vec(IxDyn(&shape), data)
387 .map_err(|e| IoError::FormatError(format!("Invalid shape: {:?}", e)))?;
388
389 Ok((array, serialized.metadata.metadata))
390}
391
392#[allow(dead_code)]
427pub fn serialize_struct<P, T>(path: P, data: &T, format: SerializationFormat) -> Result<()>
428where
429 P: AsRef<Path>,
430 T: Serialize,
431{
432 let file = File::create(path).map_err(|e| IoError::FileError(e.to_string()))?;
433 let mut writer = BufWriter::new(file);
434
435 match format {
436 SerializationFormat::Binary => {
437 let cfg = bincode_cfg();
438 bincode_serde::encode_into_std_write(data, &mut writer, cfg)
439 .map_err(|e| IoError::SerializationError(e.to_string()))?;
440 }
441 SerializationFormat::JSON => {
442 serde_json::to_writer_pretty(&mut writer, data)
443 .map_err(|e| IoError::SerializationError(e.to_string()))?;
444 }
445 SerializationFormat::MessagePack => {
446 rmp_serde::encode::write(&mut writer, data)
447 .map_err(|e| IoError::SerializationError(e.to_string()))?;
448 }
449 }
450
451 writer
452 .flush()
453 .map_err(|e| IoError::FileError(e.to_string()))?;
454 Ok(())
455}
456
457#[allow(dead_code)]
486pub fn deserialize_struct<P, T>(path: P, format: SerializationFormat) -> Result<T>
487where
488 P: AsRef<Path>,
489 T: for<'de> Deserialize<'de>,
490{
491 let file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
492 let mut reader = BufReader::new(file);
493
494 match format {
495 SerializationFormat::Binary => {
496 let cfg = bincode_cfg();
497 let (data, _len): (T, usize) = bincode_serde::decode_from_std_read(&mut reader, cfg)
498 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
499 Ok(data)
500 }
501 SerializationFormat::JSON => {
502 let data = serde_json::from_reader(reader)
503 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
504 Ok(data)
505 }
506 SerializationFormat::MessagePack => {
507 let data = rmp_serde::from_read(reader)
508 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
509 Ok(data)
510 }
511 }
512}
513
514#[derive(Debug, Clone, Serialize, Deserialize)]
516pub struct SparseMatrixCOO<A> {
517 pub rows: usize,
519 pub cols: usize,
521 pub row_indices: Vec<usize>,
523 pub col_indices: Vec<usize>,
525 pub values: Vec<A>,
527 pub metadata: std::collections::HashMap<String, String>,
529}
530
531impl<A> SparseMatrixCOO<A> {
532 pub fn new(rows: usize, cols: usize) -> Self {
534 Self {
535 rows,
536 cols,
537 row_indices: Vec::new(),
538 col_indices: Vec::new(),
539 values: Vec::new(),
540 metadata: std::collections::HashMap::new(),
541 }
542 }
543
544 pub fn push(&mut self, row: usize, col: usize, value: A) {
546 if row < self.rows && col < self.cols {
547 self.row_indices.push(row);
548 self.col_indices.push(col);
549 self.values.push(value);
550 }
551 }
552
553 pub fn nnz(&self) -> usize {
555 self.values.len()
556 }
557}
558
559#[allow(dead_code)]
589pub fn serialize_sparse_matrix<P, A>(
590 path: P,
591 matrix: &SparseMatrixCOO<A>,
592 format: SerializationFormat,
593) -> Result<()>
594where
595 P: AsRef<Path>,
596 A: Serialize,
597{
598 serialize_struct(path, matrix, format)
599}
600
601#[allow(dead_code)]
622pub fn deserialize_sparse_matrix<P, A>(
623 path: P,
624 format: SerializationFormat,
625) -> Result<SparseMatrixCOO<A>>
626where
627 P: AsRef<Path>,
628 A: for<'de> Deserialize<'de>,
629{
630 deserialize_struct(path, format)
631}
632
633#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
635pub enum SparseFormat {
636 COO,
638 CSR,
640 CSC,
642}
643
644#[derive(Debug, Clone, Serialize, Deserialize)]
646pub struct SparseMatrix<A> {
647 pub shape: (usize, usize),
649 pub format: SparseFormat,
651 pub coo_data: SparseMatrixCOO<A>,
653 #[serde(skip)]
655 pub csr_data: Option<SparseMatrixCSR<A>>,
656 #[serde(skip)]
658 pub csc_data: Option<SparseMatrixCSC<A>>,
659 pub metadata: HashMap<String, String>,
661}
662
663#[derive(Debug, Clone, Serialize, Deserialize)]
665pub struct SparseMatrixCSR<A> {
666 pub rows: usize,
668 pub cols: usize,
670 pub row_ptrs: Vec<usize>,
672 pub col_indices: Vec<usize>,
674 pub values: Vec<A>,
676 pub metadata: HashMap<String, String>,
678}
679
680#[derive(Debug, Clone, Serialize, Deserialize)]
682pub struct SparseMatrixCSC<A> {
683 pub rows: usize,
685 pub cols: usize,
687 pub col_ptrs: Vec<usize>,
689 pub row_indices: Vec<usize>,
691 pub values: Vec<A>,
693 pub metadata: HashMap<String, String>,
695}
696
697impl<A: Clone> SparseMatrix<A> {
698 pub fn from_coo(coo: SparseMatrixCOO<A>) -> Self {
700 let shape = (coo.rows, coo.cols);
701 Self {
702 shape,
703 format: SparseFormat::COO,
704 coo_data: coo,
705 csr_data: None,
706 csc_data: None,
707 metadata: HashMap::new(),
708 }
709 }
710
711 pub fn new(rows: usize, cols: usize) -> Self {
713 Self {
714 shape: (rows, cols),
715 format: SparseFormat::COO,
716 coo_data: SparseMatrixCOO::new(rows, cols),
717 csr_data: None,
718 csc_data: None,
719 metadata: HashMap::new(),
720 }
721 }
722
723 pub fn insert(&mut self, row: usize, col: usize, value: A) {
725 self.coo_data.push(row, col, value);
726 self.csr_data = None;
728 self.csc_data = None;
729 }
730
731 pub fn nnz(&self) -> usize {
733 self.coo_data.nnz()
734 }
735
736 pub fn shape(&self) -> (usize, usize) {
738 self.shape
739 }
740
741 pub fn to_csr(&mut self) -> Result<&SparseMatrixCSR<A>>
743 where
744 A: Clone + Default + PartialEq,
745 {
746 if self.csr_data.is_none() {
747 self.csr_data = Some(self.convert_to_csr()?);
748 }
749 Ok(self.csr_data.as_ref().unwrap())
750 }
751
752 pub fn to_csc(&mut self) -> Result<&SparseMatrixCSC<A>>
754 where
755 A: Clone + Default + PartialEq,
756 {
757 if self.csc_data.is_none() {
758 self.csc_data = Some(self.convert_to_csc()?);
759 }
760 Ok(self.csc_data.as_ref().unwrap())
761 }
762
763 fn convert_to_csr(&self) -> Result<SparseMatrixCSR<A>>
765 where
766 A: Clone + Default,
767 {
768 let nnz = self.coo_data.nnz();
769 let rows = self.shape.0;
770
771 if nnz == 0 {
772 return Ok(SparseMatrixCSR {
773 rows,
774 cols: self.shape.1,
775 row_ptrs: vec![0; rows + 1],
776 col_indices: Vec::new(),
777 values: Vec::new(),
778 metadata: self.metadata.clone(),
779 });
780 }
781
782 let mut triplets: Vec<(usize, usize, A)> = self
784 .coo_data
785 .row_indices
786 .iter()
787 .zip(self.coo_data.col_indices.iter())
788 .zip(self.coo_data.values.iter())
789 .map(|((&r, &c), v)| (r, c, v.clone()))
790 .collect();
791
792 triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
793
794 let mut row_ptrs = vec![0; rows + 1];
796 let mut col_indices = Vec::with_capacity(nnz);
797 let mut values = Vec::with_capacity(nnz);
798
799 let mut current_row = 0;
800
801 for (i, (row, col, val)) in triplets.iter().enumerate() {
802 while current_row < *row {
804 current_row += 1;
805 row_ptrs[current_row] = i;
806 }
807
808 col_indices.push(*col);
809 values.push(val.clone());
810 }
811
812 while current_row < rows {
814 current_row += 1;
815 row_ptrs[current_row] = nnz;
816 }
817
818 Ok(SparseMatrixCSR {
819 rows,
820 cols: self.shape.1,
821 row_ptrs,
822 col_indices,
823 values,
824 metadata: self.metadata.clone(),
825 })
826 }
827
828 fn convert_to_csc(&self) -> Result<SparseMatrixCSC<A>>
830 where
831 A: Clone + Default,
832 {
833 let nnz = self.coo_data.nnz();
834 let cols = self.shape.1;
835
836 if nnz == 0 {
837 return Ok(SparseMatrixCSC {
838 rows: self.shape.0,
839 cols,
840 col_ptrs: vec![0; cols + 1],
841 row_indices: Vec::new(),
842 values: Vec::new(),
843 metadata: self.metadata.clone(),
844 });
845 }
846
847 let mut triplets: Vec<(usize, usize, A)> = self
849 .coo_data
850 .row_indices
851 .iter()
852 .zip(self.coo_data.col_indices.iter())
853 .zip(self.coo_data.values.iter())
854 .map(|((&r, &c), v)| (r, c, v.clone()))
855 .collect();
856
857 triplets.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
858
859 let mut col_ptrs = vec![0; cols + 1];
861 let mut row_indices = Vec::with_capacity(nnz);
862 let mut values = Vec::with_capacity(nnz);
863
864 let mut current_col = 0;
865
866 for (i, (row, col, val)) in triplets.iter().enumerate() {
867 while current_col < *col {
869 current_col += 1;
870 col_ptrs[current_col] = i;
871 }
872
873 row_indices.push(*row);
874 values.push(val.clone());
875 }
876
877 while current_col < cols {
879 current_col += 1;
880 col_ptrs[current_col] = nnz;
881 }
882
883 Ok(SparseMatrixCSC {
884 rows: self.shape.0,
885 cols,
886 col_ptrs,
887 row_indices,
888 values,
889 metadata: self.metadata.clone(),
890 })
891 }
892
893 pub fn to_dense(&self) -> Array2<A>
895 where
896 A: Clone + Default,
897 {
898 let mut dense = Array2::default(self.shape);
899
900 for ((row, col), value) in self
901 .coo_data
902 .row_indices
903 .iter()
904 .zip(self.coo_data.col_indices.iter())
905 .zip(self.coo_data.values.iter())
906 {
907 dense[[*row, *col]] = value.clone();
908 }
909
910 dense
911 }
912
913 pub fn sparsity(&self) -> f64 {
915 let total_elements = self.shape.0 * self.shape.1;
916 if total_elements == 0 {
917 0.0
918 } else {
919 1.0 - (self.nnz() as f64 / total_elements as f64)
920 }
921 }
922
923 pub fn memory_usage(&self) -> usize {
925 let coo_size = self.coo_data.values.len()
926 * (std::mem::size_of::<A>() + 2 * std::mem::size_of::<usize>());
927
928 let csr_size = if let Some(ref csr) = self.csr_data {
929 csr.values.len() * std::mem::size_of::<A>()
930 + csr.col_indices.len() * std::mem::size_of::<usize>()
931 + csr.row_ptrs.len() * std::mem::size_of::<usize>()
932 } else {
933 0
934 };
935
936 let csc_size = if let Some(ref csc) = self.csc_data {
937 csc.values.len() * std::mem::size_of::<A>()
938 + csc.row_indices.len() * std::mem::size_of::<usize>()
939 + csc.col_ptrs.len() * std::mem::size_of::<usize>()
940 } else {
941 0
942 };
943
944 coo_size + csr_size + csc_size
945 }
946}
947
948impl<A: Clone> SparseMatrixCSR<A> {
949 pub fn new(rows: usize, cols: usize) -> Self {
951 Self {
952 rows,
953 cols,
954 row_ptrs: vec![0; rows + 1],
955 col_indices: Vec::new(),
956 values: Vec::new(),
957 metadata: HashMap::new(),
958 }
959 }
960
961 pub fn nnz(&self) -> usize {
963 self.values.len()
964 }
965
966 pub fn shape(&self) -> (usize, usize) {
968 (self.rows, self.cols)
969 }
970
971 pub fn row(&self, row: usize) -> Option<(&[usize], &[A])> {
973 if row >= self.rows {
974 return None;
975 }
976
977 let start = self.row_ptrs[row];
978 let end = self.row_ptrs[row + 1];
979
980 Some((&self.col_indices[start..end], &self.values[start..end]))
981 }
982}
983
984impl<A: Clone> SparseMatrixCSC<A> {
985 pub fn new(rows: usize, cols: usize) -> Self {
987 Self {
988 rows,
989 cols,
990 col_ptrs: vec![0; cols + 1],
991 row_indices: Vec::new(),
992 values: Vec::new(),
993 metadata: HashMap::new(),
994 }
995 }
996
997 pub fn nnz(&self) -> usize {
999 self.values.len()
1000 }
1001
1002 pub fn shape(&self) -> (usize, usize) {
1004 (self.rows, self.cols)
1005 }
1006
1007 pub fn column(&self, col: usize) -> Option<(&[usize], &[A])> {
1009 if col >= self.cols {
1010 return None;
1011 }
1012
1013 let start = self.col_ptrs[col];
1014 let end = self.col_ptrs[col + 1];
1015
1016 Some((&self.row_indices[start..end], &self.values[start..end]))
1017 }
1018}
1019
1020#[allow(dead_code)]
1022pub fn serialize_enhanced_sparse_matrix<P, A>(
1023 path: P,
1024 matrix: &SparseMatrix<A>,
1025 format: SerializationFormat,
1026) -> Result<()>
1027where
1028 P: AsRef<Path>,
1029 A: Serialize,
1030{
1031 serialize_struct(path, matrix, format)
1032}
1033
1034#[allow(dead_code)]
1036pub fn deserialize_enhanced_sparse_matrix<P, A>(
1037 path: P,
1038 format: SerializationFormat,
1039) -> Result<SparseMatrix<A>>
1040where
1041 P: AsRef<Path>,
1042 A: for<'de> Deserialize<'de> + Default,
1043{
1044 deserialize_struct(path, format)
1045}
1046
1047#[allow(dead_code)]
1049pub fn from_matrix_market<A>(mm_matrix: &crate::matrix_market::MMSparseMatrix<A>) -> SparseMatrix<A>
1050where
1051 A: Clone,
1052{
1053 let mut coo = SparseMatrixCOO::new(mm_matrix.rows, mm_matrix.cols);
1054
1055 for entry in &mm_matrix.entries {
1056 coo.push(entry.row, entry.col, entry.value.clone());
1057 }
1058
1059 let mut sparse = SparseMatrix::from_coo(coo);
1060 sparse
1061 .metadata
1062 .insert("source".to_string(), "Matrix Market".to_string());
1063 sparse.metadata.insert(
1064 "format".to_string(),
1065 format!("{:?}", mm_matrix.header.format),
1066 );
1067 sparse.metadata.insert(
1068 "data_type".to_string(),
1069 format!("{:?}", mm_matrix.header.data_type),
1070 );
1071 sparse.metadata.insert(
1072 "symmetry".to_string(),
1073 format!("{:?}", mm_matrix.header.symmetry),
1074 );
1075
1076 sparse
1077}
1078
1079#[allow(dead_code)]
1081pub fn to_matrix_market<A>(sparse: &SparseMatrix<A>) -> crate::matrix_market::MMSparseMatrix<A>
1082where
1083 A: Clone,
1084{
1085 let header = crate::matrix_market::MMHeader {
1086 object: "matrix".to_string(),
1087 format: crate::matrix_market::MMFormat::Coordinate,
1088 data_type: crate::matrix_market::MMDataType::Real, symmetry: crate::matrix_market::MMSymmetry::General, comments: vec!["Converted from enhanced _sparse matrix".to_string()],
1091 };
1092
1093 let entries = sparse
1094 .coo_data
1095 .row_indices
1096 .iter()
1097 .zip(sparse.coo_data.col_indices.iter())
1098 .zip(sparse.coo_data.values.iter())
1099 .map(|((&row, &col), value)| crate::matrix_market::SparseEntry {
1100 row,
1101 col,
1102 value: value.clone(),
1103 })
1104 .collect();
1105
1106 crate::matrix_market::MMSparseMatrix {
1107 header,
1108 rows: sparse.shape.0,
1109 cols: sparse.shape.1,
1110 nnz: sparse.nnz(),
1111 entries,
1112 }
1113}
1114
1115pub mod sparse_ops {
1117 use super::*;
1118
1119 pub fn add_coo<A>(a: &SparseMatrixCOO<A>, b: &SparseMatrixCOO<A>) -> Result<SparseMatrixCOO<A>>
1121 where
1122 A: Clone + std::ops::Add<Output = A> + Default + PartialEq,
1123 {
1124 if a.rows != b.rows || a.cols != b.cols {
1125 return Err(IoError::ValidationError(
1126 "Matrix dimensions must match".to_string(),
1127 ));
1128 }
1129
1130 let mut result = SparseMatrixCOO::new(a.rows, a.cols);
1131 let mut indices_map: HashMap<(usize, usize), A> = HashMap::new();
1132
1133 for ((row, col), value) in a
1135 .row_indices
1136 .iter()
1137 .zip(a.col_indices.iter())
1138 .zip(a.values.iter())
1139 {
1140 indices_map.insert((*row, *col), value.clone());
1141 }
1142
1143 for ((row, col), value) in b
1145 .row_indices
1146 .iter()
1147 .zip(b.col_indices.iter())
1148 .zip(b.values.iter())
1149 {
1150 let key = (*row, *col);
1151 if let Some(existing) = indices_map.get(&key) {
1152 indices_map.insert(key, existing.clone() + value.clone());
1153 } else {
1154 indices_map.insert(key, value.clone());
1155 }
1156 }
1157
1158 for ((row, col), value) in indices_map {
1160 if value != A::default() {
1161 result.push(row, col, value);
1163 }
1164 }
1165
1166 Ok(result)
1167 }
1168
1169 pub fn csr_matvec<A>(matrix: &SparseMatrixCSR<A>, vector: &[A]) -> Result<Vec<A>>
1171 where
1172 A: Clone + std::ops::Add<Output = A> + std::ops::Mul<Output = A> + Default,
1173 {
1174 if vector.len() != matrix.cols {
1175 return Err(IoError::ValidationError(
1176 "Vector dimension must match _matrix columns".to_string(),
1177 ));
1178 }
1179
1180 let mut result = vec![A::default(); matrix.rows];
1181
1182 for (row, result_elem) in result.iter_mut().enumerate() {
1183 let start = matrix.row_ptrs[row];
1184 let end = matrix.row_ptrs[row + 1];
1185
1186 let mut sum = A::default();
1187 for i in start..end {
1188 let col = matrix.col_indices[i];
1189 let val = matrix.values[i].clone();
1190 sum = sum + (val * vector[col].clone());
1191 }
1192 *result_elem = sum;
1193 }
1194
1195 Ok(result)
1196 }
1197
1198 pub fn transpose_coo<A>(matrix: &SparseMatrixCOO<A>) -> SparseMatrixCOO<A>
1200 where
1201 A: Clone,
1202 {
1203 let mut result = SparseMatrixCOO::new(matrix.cols, matrix.rows);
1204
1205 for ((row, col), value) in matrix
1206 .row_indices
1207 .iter()
1208 .zip(matrix.col_indices.iter())
1209 .zip(matrix.values.iter())
1210 {
1211 result.push(*col, *row, value.clone());
1212 }
1213
1214 result
1215 }
1216}
1217
1218#[allow(dead_code)]
1222pub fn write_array_json<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
1223where
1224 P: AsRef<Path>,
1225 A: Serialize + Clone,
1226 S: scirs2_core::ndarray::Data<Elem = A>,
1227{
1228 serialize_array::<P, A, S>(path, array, SerializationFormat::JSON)
1229}
1230
1231#[allow(dead_code)]
1233pub fn read_array_json<P, A>(path: P) -> Result<Array<A, IxDyn>>
1234where
1235 P: AsRef<Path>,
1236 A: for<'de> Deserialize<'de> + Clone,
1237{
1238 deserialize_array(path, SerializationFormat::JSON)
1239}
1240
1241#[allow(dead_code)]
1243pub fn write_array_binary<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
1244where
1245 P: AsRef<Path>,
1246 A: Serialize + Clone,
1247 S: scirs2_core::ndarray::Data<Elem = A>,
1248{
1249 serialize_array::<P, A, S>(path, array, SerializationFormat::Binary)
1250}
1251
1252#[allow(dead_code)]
1254pub fn read_array_binary<P, A>(path: P) -> Result<Array<A, IxDyn>>
1255where
1256 P: AsRef<Path>,
1257 A: for<'de> Deserialize<'de> + Clone,
1258{
1259 deserialize_array(path, SerializationFormat::Binary)
1260}
1261
1262#[allow(dead_code)]
1264pub fn write_array_messagepack<P, A, S>(path: P, array: &ArrayBase<S, IxDyn>) -> Result<()>
1265where
1266 P: AsRef<Path>,
1267 A: Serialize + Clone,
1268 S: scirs2_core::ndarray::Data<Elem = A>,
1269{
1270 serialize_array::<P, A, S>(path, array, SerializationFormat::MessagePack)
1271}
1272
1273#[allow(dead_code)]
1275pub fn read_array_messagepack<P, A>(path: P) -> Result<Array<A, IxDyn>>
1276where
1277 P: AsRef<Path>,
1278 A: for<'de> Deserialize<'de> + Clone,
1279{
1280 deserialize_array(path, SerializationFormat::MessagePack)
1281}
1282
1283#[allow(dead_code)]
1304pub fn serialize_array_zero_copy<P, A, S>(
1305 path: P,
1306 array: &ArrayBase<S, IxDyn>,
1307 format: SerializationFormat,
1308) -> Result<()>
1309where
1310 P: AsRef<Path>,
1311 A: Serialize + bytemuck::Pod,
1312 S: scirs2_core::ndarray::Data<Elem = A>,
1313{
1314 if !array.is_standard_layout() {
1315 return Err(IoError::FormatError(
1316 "Array must be in standard layout for zero-copy serialization".to_string(),
1317 ));
1318 }
1319
1320 let file = File::create(&path).map_err(|e| IoError::FileError(e.to_string()))?;
1321 let mut writer = BufWriter::new(file);
1322
1323 let shape = array.shape().to_vec();
1325 let metadata = ArrayMetadata {
1326 shape: shape.clone(),
1327 dtype: std::any::type_name::<A>().to_string(),
1328 order: 'C',
1329 metadata: HashMap::new(),
1330 };
1331
1332 match format {
1333 SerializationFormat::Binary => {
1334 let cfg = bincode_cfg();
1336 bincode_serde::encode_into_std_write(&metadata, &mut writer, cfg)
1337 .map_err(|e| IoError::SerializationError(e.to_string()))?;
1338
1339 if let Some(slice) = array.as_slice() {
1341 let bytes = bytemuck::cast_slice(slice);
1342 writer
1343 .write_all(bytes)
1344 .map_err(|e| IoError::FileError(e.to_string()))?;
1345 }
1346 }
1347 _ => {
1348 return serialize_array(path, array, format);
1351 }
1352 }
1353
1354 writer
1355 .flush()
1356 .map_err(|e| IoError::FileError(e.to_string()))?;
1357 Ok(())
1358}
1359
1360#[allow(dead_code)]
1378pub fn deserialize_array_zero_copy<P>(path: P) -> Result<(ArrayMetadata, memmap2::Mmap)>
1379where
1380 P: AsRef<Path>,
1381{
1382 use std::io::Read;
1383
1384 let mut file = File::open(path).map_err(|e| IoError::FileError(e.to_string()))?;
1385
1386 let mut size_buf = [0u8; 8];
1388 file.read_exact(&mut size_buf)
1389 .map_err(|e| IoError::FileError(e.to_string()))?;
1390 let metadata_size = u64::from_le_bytes(size_buf) as usize;
1391
1392 let mut metadata_buf = vec![0u8; metadata_size];
1394 file.read_exact(&mut metadata_buf)
1395 .map_err(|e| IoError::FileError(e.to_string()))?;
1396
1397 let cfg = bincode_cfg();
1398 let (metadata, _len): (ArrayMetadata, usize) =
1399 bincode_serde::decode_from_slice(&metadata_buf, cfg)
1400 .map_err(|e| IoError::DeserializationError(e.to_string()))?;
1401
1402 let mmap = unsafe {
1404 memmap2::MmapOptions::new()
1405 .offset(8 + metadata_size as u64)
1406 .map(&file)
1407 .map_err(|e| IoError::FileError(e.to_string()))?
1408 };
1409
1410 Ok((metadata, mmap))
1411}