use crate::dtype::DType;
use crate::error::TorshError;
use crate::shape::Shape;

use std::fmt;
use std::sync::Arc;

/// Sparse tensor storage formats supported by the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Coordinate format: one explicit (row, col, ...) index tuple per non-zero.
    COO,

    /// Compressed sparse row: row pointers plus per-entry column indices.
    CSR,

    /// Compressed sparse column: column pointers plus per-entry row indices.
    CSC,

    /// Block sparse row: CSR over fixed-size dense blocks.
    BSR,

    /// Diagonal format: values stored along a fixed set of diagonals.
    DIA,

    /// ELLPACK format: padded rows with a fixed number of entries per row.
    ELL,
}

impl SparseFormat {
    /// Short human-readable name of the format.
    pub fn name(self) -> &'static str {
        match self {
            Self::COO => "COO",
            Self::CSR => "CSR",
            Self::CSC => "CSC",
            Self::BSR => "BSR",
            Self::DIA => "DIA",
            Self::ELL => "ELL",
        }
    }

    /// Whether the format supports efficient row-wise access.
    pub fn supports_row_access(self) -> bool {
        matches!(self, Self::CSR | Self::BSR)
    }

    /// Whether the format supports efficient column-wise access.
    pub fn supports_column_access(self) -> bool {
        matches!(self, Self::CSC)
    }

    /// Whether the format has a regular enough layout to map well onto GPU kernels.
    pub fn is_gpu_friendly(self) -> bool {
        matches!(self, Self::CSR | Self::ELL | Self::BSR)
    }
}
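
// Illustrative usage (a sketch, not taken from the original sources): the predicates
// above let callers pick a format that matches the access pattern they need.
//
//     let fmt = SparseFormat::CSR;
//     assert!(fmt.supports_row_access());
//     assert!(!fmt.supports_column_access());
//     assert!(fmt.is_gpu_friendly());
//     assert_eq!(fmt.name(), "CSR");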

impl fmt::Display for SparseFormat {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.name())
    }
}

/// Metadata describing a sparse tensor's layout and compression characteristics.
#[derive(Debug, Clone)]
pub struct SparseMetadata {
    /// Storage format used for the tensor.
    format: SparseFormat,

    /// Number of explicitly stored (non-zero) elements.
    nnz: usize,

    /// Fraction of elements that are zero, in `[0.0, 1.0]`.
    sparsity: f32,

    /// Whether the indices are sorted in canonical (row-major) order.
    indices_sorted: bool,

    /// Whether duplicate entries have been coalesced by summation.
    duplicates_summed: bool,

    /// Block dimensions, used by the BSR format.
    block_size: Option<(usize, usize)>,

    /// Number of stored diagonals, used by the DIA format.
    num_diagonals: Option<usize>,

    /// Maximum number of non-zeros per row, used by the ELL format.
    ell_width: Option<usize>,

    /// Memory footprint of the sparse layout relative to a dense one.
    compression_stats: CompressionStats,
}

/// Memory statistics comparing the sparse representation against a dense one.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Size of the equivalent dense tensor, in bytes.
    dense_size_bytes: usize,

    /// Size of the sparse representation, in bytes.
    sparse_size_bytes: usize,

    /// Ratio of dense size to sparse size; values above 1.0 favor the sparse layout.
    compression_ratio: f32,

    /// Bytes spent on index storage rather than values.
    #[allow(dead_code)]
    index_overhead_bytes: usize,
}

impl SparseMetadata {
    /// Builds metadata from the basic size figures of a sparse tensor.
    pub fn new(
        format: SparseFormat,
        nnz: usize,
        total_elements: usize,
        dense_size_bytes: usize,
        sparse_size_bytes: usize,
    ) -> Self {
        let sparsity = 1.0 - (nnz as f32 / total_elements as f32);
        let compression_ratio = dense_size_bytes as f32 / sparse_size_bytes as f32;

        Self {
            format,
            nnz,
            sparsity,
            indices_sorted: false,
            duplicates_summed: false,
            block_size: None,
            num_diagonals: None,
            ell_width: None,
            compression_stats: CompressionStats {
                dense_size_bytes,
                sparse_size_bytes,
                compression_ratio,
                // The estimate assumes 4-byte values; saturate so small buffers cannot underflow.
                index_overhead_bytes: sparse_size_bytes.saturating_sub(nnz * 4),
            },
        }
    }

    pub fn format(&self) -> SparseFormat {
        self.format
    }

    pub fn nnz(&self) -> usize {
        self.nnz
    }

    pub fn sparsity(&self) -> f32 {
        self.sparsity
    }

    pub fn density(&self) -> f32 {
        1.0 - self.sparsity
    }

    pub fn indices_sorted(&self) -> bool {
        self.indices_sorted
    }

    pub fn set_indices_sorted(&mut self, sorted: bool) {
        self.indices_sorted = sorted;
    }

    pub fn duplicates_summed(&self) -> bool {
        self.duplicates_summed
    }

    pub fn set_duplicates_summed(&mut self, summed: bool) {
        self.duplicates_summed = summed;
    }

    pub fn block_size(&self) -> Option<(usize, usize)> {
        self.block_size
    }

    pub fn set_block_size(&mut self, size: (usize, usize)) {
        self.block_size = Some(size);
    }

    pub fn compression_stats(&self) -> &CompressionStats {
        &self.compression_stats
    }

    /// True when the sparse layout is at least ~20% smaller than the dense one.
    pub fn is_beneficial(&self) -> bool {
        self.compression_stats.compression_ratio > 1.2
    }

    /// Bytes saved relative to a dense layout; negative if the sparse layout is larger.
    pub fn memory_savings_bytes(&self) -> i64 {
        self.compression_stats.dense_size_bytes as i64
            - self.compression_stats.sparse_size_bytes as i64
    }

    pub fn format_info(&self) -> String {
        match self.format {
            SparseFormat::BSR => {
                if let Some((bm, bn)) = self.block_size {
                    format!("BSR({}x{})", bm, bn)
                } else {
                    "BSR".to_string()
                }
            }
            SparseFormat::DIA => {
                if let Some(ndiag) = self.num_diagonals {
                    format!("DIA({})", ndiag)
                } else {
                    "DIA".to_string()
                }
            }
            SparseFormat::ELL => {
                if let Some(width) = self.ell_width {
                    format!("ELL({})", width)
                } else {
                    "ELL".to_string()
                }
            }
            _ => self.format.name().to_string(),
        }
    }
}

impl fmt::Display for SparseMetadata {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "SparseMetadata({}, nnz={}, sparsity={:.2}%, compression={:.1}x)",
            self.format_info(),
            self.nnz,
            self.sparsity * 100.0,
            self.compression_stats.compression_ratio
        )
    }
}
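
// Illustrative usage (a sketch, not taken from the original sources): metadata for a
// tensor with 10_000 elements, 1_000 of them non-zero, stored in 12_000 bytes instead
// of 40_000 dense bytes.
//
//     let meta = SparseMetadata::new(SparseFormat::CSR, 1_000, 10_000, 40_000, 12_000);
//     assert_eq!(meta.nnz(), 1_000);
//     assert!((meta.sparsity() - 0.9).abs() < 1e-6);
//     assert!(meta.is_beneficial());            // 40_000 / 12_000 > 1.2
//     println!("{}", meta);                     // "SparseMetadata(CSR, nnz=1000, ...)"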

/// Index arrays for the COO (coordinate) format.
#[derive(Debug, Clone)]
pub struct CooIndices {
    /// Row index of each stored element.
    pub rows: Vec<usize>,

    /// Column index of each stored element.
    pub cols: Vec<usize>,

    /// Indices for dimensions beyond the first two, one vector per extra dimension.
    pub extra_dims: Vec<Vec<usize>>,
}

impl CooIndices {
    /// Creates 2-D COO indices from parallel row and column vectors.
    pub fn new_2d(rows: Vec<usize>, cols: Vec<usize>) -> Self {
        assert_eq!(
            rows.len(),
            cols.len(),
            "Row and column indices must have same length"
        );

        Self {
            rows,
            cols,
            extra_dims: Vec::new(),
        }
    }

    /// Creates N-D COO indices from one index vector per dimension (at least two).
    pub fn new_nd(indices: Vec<Vec<usize>>) -> Self {
        let nnz = indices.first().map_or(0, |dim| dim.len());

        for (i, dim_indices) in indices.iter().enumerate() {
            assert_eq!(
                dim_indices.len(),
                nnz,
                "Dimension {} indices length mismatch: expected {}, got {}",
                i,
                nnz,
                dim_indices.len()
            );
        }

        if indices.len() < 2 {
            panic!("N-D tensor must have at least 2 dimensions");
        }

        Self {
            rows: indices[0].clone(),
            cols: indices[1].clone(),
            extra_dims: if indices.len() > 2 {
                indices[2..].to_vec()
            } else {
                Vec::new()
            },
        }
    }

    /// Number of stored (non-zero) elements.
    pub fn nnz(&self) -> usize {
        self.rows.len()
    }

    /// Number of indexed dimensions.
    pub fn ndim(&self) -> usize {
        2 + self.extra_dims.len()
    }

    /// Returns true if the entries are sorted lexicographically by (row, col).
    pub fn is_sorted(&self) -> bool {
        for i in 1..self.rows.len() {
            if self.rows[i] < self.rows[i - 1] {
                return false;
            }
            if self.rows[i] == self.rows[i - 1] && self.cols[i] < self.cols[i - 1] {
                return false;
            }
        }
        true
    }

    /// Sorts the indices lexicographically (row, then column, then extra dimensions) and
    /// returns the permutation that was applied, so callers can reorder the matching
    /// value buffer the same way.
    pub fn sort(&mut self) -> Vec<usize> {
        let mut perm: Vec<usize> = (0..self.nnz()).collect();

        perm.sort_by(|&a, &b| {
            match self.rows[a].cmp(&self.rows[b]) {
                std::cmp::Ordering::Equal => {
                    match self.cols[a].cmp(&self.cols[b]) {
                        std::cmp::Ordering::Equal => {
                            for dim_indices in &self.extra_dims {
                                match dim_indices[a].cmp(&dim_indices[b]) {
                                    std::cmp::Ordering::Equal => continue,
                                    other => return other,
                                }
                            }
                            std::cmp::Ordering::Equal
                        }
                        other => other,
                    }
                }
                other => other,
            }
        });

        // Apply the permutation to every index array.
        let orig_rows = self.rows.clone();
        let orig_cols = self.cols.clone();
        let orig_extra: Vec<_> = self.extra_dims.clone();

        for (i, &p) in perm.iter().enumerate() {
            self.rows[i] = orig_rows[p];
            self.cols[i] = orig_cols[p];
            for (dim_idx, orig_dim) in orig_extra.iter().enumerate() {
                self.extra_dims[dim_idx][i] = orig_dim[p];
            }
        }

        perm
    }
}
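
// Illustrative usage (a sketch, not taken from the original sources): `sort` returns the
// permutation it applied, which the caller is expected to use for the value buffer.
//
//     let mut idx = CooIndices::new_2d(vec![1, 0], vec![0, 2]);
//     let values = vec![10.0_f32, 20.0];
//     let perm = idx.sort();                                   // perm == [1, 0]
//     let sorted: Vec<f32> = perm.iter().map(|&p| values[p]).collect();
//     assert_eq!(sorted, vec![20.0, 10.0]);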

/// Index arrays for the CSR (compressed sparse row) format.
#[derive(Debug, Clone)]
pub struct CsrIndices {
    /// Row pointers of length `nrows + 1`; row `i` occupies `row_ptrs[i]..row_ptrs[i + 1]`.
    pub row_ptrs: Vec<usize>,

    /// Column index of each stored element, ordered row by row.
    pub col_indices: Vec<usize>,
}

impl CsrIndices {
    /// Creates CSR indices, validating that the row pointers are consistent.
    pub fn new(row_ptrs: Vec<usize>, col_indices: Vec<usize>) -> Self {
        let nnz = col_indices.len();
        let _nrows = row_ptrs.len().saturating_sub(1);

        assert_eq!(
            *row_ptrs.last().unwrap_or(&0),
            nnz,
            "Last row pointer must equal nnz"
        );

        for i in 1..row_ptrs.len() {
            assert!(
                row_ptrs[i] >= row_ptrs[i - 1],
                "Row pointers must be non-decreasing"
            );
        }

        Self {
            row_ptrs,
            col_indices,
        }
    }

    /// Builds CSR indices from COO indices.
    ///
    /// Note: the column indices are taken in their existing order, so the COO input is
    /// expected to already be sorted by row (see `CooIndices::sort`).
    pub fn from_coo(coo: &CooIndices, nrows: usize) -> Self {
        let _nnz = coo.nnz();
        let mut row_ptrs = vec![0; nrows + 1];

        // Count the entries in each row...
        for &row in &coo.rows {
            if row < nrows {
                row_ptrs[row + 1] += 1;
            }
        }

        // ...then turn the counts into prefix sums to obtain row pointers.
        for i in 1..=nrows {
            row_ptrs[i] += row_ptrs[i - 1];
        }

        let col_indices = coo.cols.clone();

        Self::new(row_ptrs, col_indices)
    }

    /// Number of rows described by the row pointers.
    pub fn nrows(&self) -> usize {
        self.row_ptrs.len().saturating_sub(1)
    }

    /// Number of stored (non-zero) elements.
    pub fn nnz(&self) -> usize {
        self.col_indices.len()
    }

    /// Range into `col_indices` (and the value buffer) covered by `row`, if it exists.
    pub fn row_range(&self, row: usize) -> Option<std::ops::Range<usize>> {
        if row >= self.nrows() {
            return None;
        }
        Some(self.row_ptrs[row]..self.row_ptrs[row + 1])
    }
}
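
// Illustrative usage (a sketch, not taken from the original sources): iterating the
// columns of one row through `row_range`.
//
//     let csr = CsrIndices::new(vec![0, 2, 3], vec![1, 2, 0]);   // 2 rows, 3 entries
//     for idx in csr.row_range(0).unwrap() {
//         println!("row 0 has a value in column {}", csr.col_indices[idx]);
//     }
//     assert!(csr.row_range(2).is_none());                        // out of bounds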

/// Common interface implemented by every sparse storage backend.
pub trait SparseStorage: Send + Sync + std::fmt::Debug {
    /// Metadata describing the stored tensor.
    fn metadata(&self) -> &SparseMetadata;

    /// Number of stored (non-zero) elements.
    fn nnz(&self) -> usize {
        self.metadata().nnz()
    }

    /// Storage format of this backend.
    fn format(&self) -> SparseFormat {
        self.metadata().format()
    }

    /// Whether the sparse layout actually saves memory over a dense one.
    fn is_beneficial(&self) -> bool {
        self.metadata().is_beneficial()
    }

    /// Converts the storage to COO format.
    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;

    /// Converts the storage to CSR format.
    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;

    /// Total memory used by values and indices, in bytes.
    fn memory_usage(&self) -> usize;
}
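
// Illustrative usage (a sketch, not taken from the original sources): the trait is
// object-safe, so callers can report on any backend through `&dyn SparseStorage`.
//
//     fn report(storage: &dyn SparseStorage) {
//         println!(
//             "{} storage, {} non-zeros, {} bytes",
//             storage.format(),
//             storage.nnz(),
//             storage.memory_usage()
//         );
//     }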

/// COO (coordinate) sparse storage: explicit index tuples plus a raw value buffer.
#[derive(Debug)]
pub struct CooStorage {
    metadata: SparseMetadata,
    indices: CooIndices,
    values: Vec<u8>,
    dtype: DType,
    shape: Shape,
}

impl CooStorage {
    /// Creates COO storage, validating that the value buffer matches `nnz * dtype.size()`.
    pub fn new(
        indices: CooIndices,
        values: Vec<u8>,
        dtype: DType,
        shape: Shape,
    ) -> Result<Self, TorshError> {
        let nnz = indices.nnz();
        let expected_value_size = nnz * dtype.size();

        if values.len() != expected_value_size {
            return Err(TorshError::InvalidArgument(format!(
                "Value buffer size mismatch: expected {}, got {}",
                expected_value_size,
                values.len()
            )));
        }

        let total_elements: usize = shape.dims().iter().product();
        let dense_size = total_elements * dtype.size();
        // 8 bytes per stored index (usize on 64-bit targets) for the row and column arrays.
        let sparse_size = values.len() + indices.rows.len() * 8 + indices.cols.len() * 8;

        let metadata = SparseMetadata::new(
            SparseFormat::COO,
            nnz,
            total_elements,
            dense_size,
            sparse_size,
        );

        Ok(Self {
            metadata,
            indices,
            values,
            dtype,
            shape,
        })
    }

    /// Shared access to the index arrays.
    pub fn indices(&self) -> &CooIndices {
        &self.indices
    }

    /// Mutable access to the index arrays.
    pub fn indices_mut(&mut self) -> &mut CooIndices {
        &mut self.indices
    }

    /// Raw value buffer bytes.
    pub fn values_bytes(&self) -> &[u8] {
        &self.values
    }

    /// Element data type.
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Logical tensor shape.
    pub fn shape(&self) -> &Shape {
        &self.shape
    }
}

impl SparseStorage for CooStorage {
    fn metadata(&self) -> &SparseMetadata {
        &self.metadata
    }

    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        // Already COO; return a copy behind a trait object.
        Ok(Arc::new(CooStorage {
            metadata: self.metadata.clone(),
            indices: self.indices.clone(),
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        if self.shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "CSR format only supports 2D tensors".to_string(),
            ));
        }

        // Note: `CsrIndices::from_coo` keeps the existing entry order, so this conversion
        // assumes the COO indices are already sorted by row; the value buffer is reused as-is.
        let nrows = self.shape.dims()[0];
        let csr_indices = CsrIndices::from_coo(&self.indices, nrows);

        Ok(Arc::new(CsrStorage {
            metadata: {
                let mut meta = self.metadata.clone();
                meta.format = SparseFormat::CSR;
                meta
            },
            indices: csr_indices,
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    fn memory_usage(&self) -> usize {
        self.values.len()
            + self.indices.rows.len() * std::mem::size_of::<usize>()
            + self.indices.cols.len() * std::mem::size_of::<usize>()
            + self
                .indices
                .extra_dims
                .iter()
                .map(|dim| dim.len() * std::mem::size_of::<usize>())
                .sum::<usize>()
    }
}

/// CSR (compressed sparse row) storage: row pointers, column indices, and a raw value buffer.
#[derive(Debug)]
pub struct CsrStorage {
    metadata: SparseMetadata,
    indices: CsrIndices,
    values: Vec<u8>,
    dtype: DType,
    shape: Shape,
}

impl CsrStorage {
    /// Creates CSR storage for a 2-D tensor, validating the value buffer size.
    pub fn new(
        indices: CsrIndices,
        values: Vec<u8>,
        dtype: DType,
        shape: Shape,
    ) -> Result<Self, TorshError> {
        if shape.ndim() != 2 {
            return Err(TorshError::InvalidArgument(
                "CSR format only supports 2D tensors".to_string(),
            ));
        }

        let nnz = indices.nnz();
        let expected_value_size = nnz * dtype.size();

        if values.len() != expected_value_size {
            return Err(TorshError::InvalidArgument(format!(
                "Value buffer size mismatch: expected {}, got {}",
                expected_value_size,
                values.len()
            )));
        }

        let total_elements: usize = shape.dims().iter().product();
        let dense_size = total_elements * dtype.size();
        // 8 bytes per stored index (usize on 64-bit targets) for row pointers and column indices.
        let sparse_size = values.len() + indices.row_ptrs.len() * 8 + indices.col_indices.len() * 8;

        let metadata = SparseMetadata::new(
            SparseFormat::CSR,
            nnz,
            total_elements,
            dense_size,
            sparse_size,
        );

        Ok(Self {
            metadata,
            indices,
            values,
            dtype,
            shape,
        })
    }

    /// Shared access to the index arrays.
    pub fn indices(&self) -> &CsrIndices {
        &self.indices
    }
}

impl SparseStorage for CsrStorage {
    fn metadata(&self) -> &SparseMetadata {
        &self.metadata
    }

    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        let mut rows = Vec::with_capacity(self.nnz());
        let mut cols = Vec::with_capacity(self.nnz());

        for row in 0..self.indices.nrows() {
            let range = self.indices.row_range(row).unwrap();
            for col_idx in range {
                rows.push(row);
                cols.push(self.indices.col_indices[col_idx]);
            }
        }

        let coo_indices = CooIndices::new_2d(rows, cols);

        Ok(Arc::new(CooStorage {
            metadata: {
                let mut meta = self.metadata.clone();
                meta.format = SparseFormat::COO;
                meta
            },
            indices: coo_indices,
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
        Ok(Arc::new(CsrStorage {
            metadata: self.metadata.clone(),
            indices: self.indices.clone(),
            values: self.values.clone(),
            dtype: self.dtype,
            shape: self.shape.clone(),
        }))
    }

    fn memory_usage(&self) -> usize {
        self.values.len()
            + self.indices.row_ptrs.len() * std::mem::size_of::<usize>()
            + self.indices.col_indices.len() * std::mem::size_of::<usize>()
    }
}

/// Helpers for analyzing dense data and choosing a sparse format.
pub mod utils {
    use super::*;

    /// Scans dense `f32` data and reports its sparsity and structural pattern.
    pub fn analyze_sparsity(data: &[f32], shape: &[usize]) -> SparseAnalysis {
        let total_elements = data.len();
        let mut nnz = 0;
        let mut pattern_info = PatternInfo::default();

        for (idx, &value) in data.iter().enumerate() {
            if value != 0.0 {
                nnz += 1;
                pattern_info.update(idx, shape);
            }
        }

        let sparsity = 1.0 - (nnz as f32 / total_elements as f32);

        SparseAnalysis {
            sparsity,
            nnz,
            total_elements,
            pattern_info,
        }
    }

    /// Recommends a storage format based on a sparsity analysis and the tensor shape.
    pub fn recommend_format(analysis: &SparseAnalysis, shape: &[usize]) -> FormatRecommendation {
        let sparsity = analysis.sparsity;
        let nnz = analysis.nnz;

        if sparsity < 0.5 {
            return FormatRecommendation {
                format: None,
                reason: "Low sparsity, dense representation more efficient".to_string(),
                confidence: 0.9,
            };
        }

        if shape.len() == 2 {
            let (nrows, ncols) = (shape[0], shape[1]);

            if analysis.pattern_info.has_structured_rows {
                return FormatRecommendation {
                    format: Some(SparseFormat::CSR),
                    reason: "Good row locality, CSR optimal for row-wise operations".to_string(),
                    confidence: 0.8,
                };
            }

            if analysis.pattern_info.has_structured_cols {
                return FormatRecommendation {
                    format: Some(SparseFormat::CSC),
                    reason: "Good column locality, CSC optimal for column-wise operations"
                        .to_string(),
                    confidence: 0.8,
                };
            }

            if nnz < (nrows + ncols) * 10 {
                return FormatRecommendation {
                    format: Some(SparseFormat::COO),
                    reason: "Very sparse matrix, COO has lowest overhead".to_string(),
                    confidence: 0.7,
                };
            }

            return FormatRecommendation {
                format: Some(SparseFormat::CSR),
                reason: "General 2D sparse matrix, CSR is default choice".to_string(),
                confidence: 0.6,
            };
        }

        FormatRecommendation {
            format: Some(SparseFormat::COO),
            reason: "Multi-dimensional tensor, COO supports arbitrary dimensions".to_string(),
            confidence: 0.8,
        }
    }
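
    // Illustrative usage (a sketch, not taken from the original sources): analyzing a
    // dense buffer and asking for a format recommendation; the all-zero buffer here
    // simply stands in for real data.
    //
    //     let data = vec![0.0_f32; 10_000];            // 100x100 matrix
    //     let shape = [100, 100];
    //     let analysis = analyze_sparsity(&data, &shape);
    //     let rec = recommend_format(&analysis, &shape);
    //     println!("{:?}: {}", rec.format, rec.reason);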

    /// Converts a dense slice into COO sparse storage, dropping elements whose absolute
    /// value is at or below `threshold` (default `1e-12`).
    ///
    /// Note: values are serialized as `f64` native-endian bytes, so `dtype` is expected to
    /// describe an 8-byte element type for the buffer-size check in `CooStorage::new` to pass.
    pub fn densify_to_sparse<T>(
        data: &[T],
        shape: &Shape,
        dtype: DType,
        threshold: Option<f64>,
    ) -> Result<Arc<dyn SparseStorage>, TorshError>
    where
        T: Clone + PartialEq + Into<f64> + Default,
    {
        let threshold = threshold.unwrap_or(1e-12);
        let zero = T::default();

        let mut indices = Vec::new();
        let mut values = Vec::new();

        for (linear_idx, value) in data.iter().enumerate() {
            let abs_val = value.clone().into().abs();
            if abs_val > threshold && *value != zero {
                let multi_idx = linear_to_multidim(linear_idx, shape.dims());
                indices.push(multi_idx);
                values.push(value.clone());
            }
        }

        if indices.is_empty() {
            return Err(TorshError::InvalidArgument(
                "No non-zero elements found".to_string(),
            ));
        }

        let value_bytes: Vec<u8> = values
            .iter()
            .flat_map(|v| {
                let val_f64 = v.clone().into();
                val_f64.to_ne_bytes()
            })
            .collect();

        let dims = shape.dims();
        match dims.len() {
            1 => {
                // Represent a 1-D tensor as a single-column 2-D COO layout.
                let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
                let cols = vec![0; rows.len()];
                let coo_indices = CooIndices::new_2d(rows, cols);
                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
            }
            2 => {
                let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
                let cols: Vec<usize> = indices.iter().map(|idx| idx[1]).collect();
                let coo_indices = CooIndices::new_2d(rows, cols);
                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
            }
            _ => {
                // Transpose the per-element index tuples into one index vector per dimension.
                let transposed_indices: Vec<Vec<usize>> = (0..dims.len())
                    .map(|dim| indices.iter().map(|idx| idx[dim]).collect())
                    .collect();
                let coo_indices = CooIndices::new_nd(transposed_indices);
                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
            }
        }
    }

    /// Converts a row-major linear index into per-dimension indices for `shape`.
    fn linear_to_multidim(linear_idx: usize, shape: &[usize]) -> Vec<usize> {
        let mut result = Vec::with_capacity(shape.len());
        let mut remaining = linear_idx;

        for &dim_size in shape.iter().rev() {
            result.push(remaining % dim_size);
            remaining /= dim_size;
        }

        result.reverse();
        result
    }
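
    // Worked example (a sketch, not taken from the original sources): for a row-major
    // shape of [3, 4], linear index 7 maps to row 1, column 3:
    //
    //     assert_eq!(linear_to_multidim(7, &[3, 4]), vec![1, 3]);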

    /// Result of scanning dense data for sparsity and structure.
    #[derive(Debug, Clone)]
    pub struct SparseAnalysis {
        pub sparsity: f32,
        pub nnz: usize,
        pub total_elements: usize,
        pub pattern_info: PatternInfo,
    }

    /// Coarse structural hints gathered while scanning the non-zero pattern.
    #[derive(Debug, Clone, Default)]
    pub struct PatternInfo {
        pub has_structured_rows: bool,
        pub has_structured_cols: bool,
        pub has_diagonal_structure: bool,
        pub has_block_structure: bool,
        pub block_size: Option<(usize, usize)>,
    }

    impl PatternInfo {
        /// Updates the pattern flags with one non-zero element at linear index `idx`.
        fn update(&mut self, idx: usize, shape: &[usize]) {
            if shape.len() == 2 {
                let (_nrows, ncols) = (shape[0], shape[1]);
                let row = idx / ncols;
                let col = idx % ncols;

                if row == col {
                    self.has_diagonal_structure = true;
                }

                // Coarse heuristic: entries landing on a 4x4 grid hint at block structure.
                if row.is_multiple_of(4) && col.is_multiple_of(4) {
                    self.has_block_structure = true;
                    self.block_size = Some((4, 4));
                }
            }
        }
    }

    /// A suggested storage format together with the reasoning behind it.
    #[derive(Debug, Clone)]
    pub struct FormatRecommendation {
        /// Recommended format, or `None` when a dense layout is preferable.
        pub format: Option<SparseFormat>,
        /// Human-readable justification for the recommendation.
        pub reason: String,
        /// Heuristic confidence in the recommendation, in `[0.0, 1.0]`.
        pub confidence: f32,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    #[test]
    fn test_sparse_metadata_creation() {
        let metadata = SparseMetadata::new(
            SparseFormat::COO,
            1000,  // nnz
            10000, // total elements
            40000, // dense size in bytes
            8000,  // sparse size in bytes
        );

        assert_eq!(metadata.format(), SparseFormat::COO);
        assert_eq!(metadata.nnz(), 1000);
        assert_eq!(metadata.sparsity(), 0.9);
        assert!(metadata.is_beneficial());
    }

    #[test]
    fn test_coo_indices_creation() {
        let rows = vec![0, 1, 2, 1];
        let cols = vec![1, 0, 2, 2];

        let indices = CooIndices::new_2d(rows.clone(), cols.clone());

        assert_eq!(indices.nnz(), 4);
        assert_eq!(indices.ndim(), 2);
        assert_eq!(indices.rows, rows);
        assert_eq!(indices.cols, cols);
    }

    #[test]
    fn test_coo_indices_sorting() {
        let mut indices = CooIndices::new_2d(vec![2, 1, 0, 1], vec![0, 2, 1, 0]);

        assert!(!indices.is_sorted());

        let _perm = indices.sort();

        assert_eq!(indices.rows, vec![0, 1, 1, 2]);
        assert_eq!(indices.cols, vec![1, 0, 2, 0]);
        assert!(indices.is_sorted());
    }

    #[test]
    fn test_csr_from_coo() {
        let coo_indices = CooIndices::new_2d(vec![0, 0, 1, 2, 2], vec![1, 2, 0, 1, 2]);

        let csr_indices = CsrIndices::from_coo(&coo_indices, 3);

        assert_eq!(csr_indices.nrows(), 3);
        assert_eq!(csr_indices.nnz(), 5);
        assert_eq!(csr_indices.row_ptrs, vec![0, 2, 3, 5]);
        assert_eq!(csr_indices.col_indices, vec![1, 2, 0, 1, 2]);
    }

    #[test]
    fn test_coo_storage_creation() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);

        let storage = CooStorage::new(indices, values, DType::F32, shape).unwrap();

        assert_eq!(storage.nnz(), 2);
        assert_eq!(storage.format(), SparseFormat::COO);
        assert_eq!(storage.dtype(), DType::F32);
    }

    #[test]
    fn test_format_conversion() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);

        let coo_storage = CooStorage::new(indices, values, DType::F32, shape).unwrap();

        let csr_storage = coo_storage.to_csr().unwrap();
        assert_eq!(csr_storage.format(), SparseFormat::CSR);

        let coo_again = csr_storage.to_coo().unwrap();
        assert_eq!(coo_again.format(), SparseFormat::COO);
    }

    #[test]
    fn test_sparsity_analysis() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
        let shape = vec![3, 3];

        let analysis = utils::analyze_sparsity(&data, &shape);

        assert_eq!(analysis.nnz, 3);
        assert_eq!(analysis.total_elements, 9);
        assert!((analysis.sparsity - 2.0 / 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_format_recommendation() {
        let analysis = utils::SparseAnalysis {
            sparsity: 0.9,
            nnz: 100,
            total_elements: 1000,
            pattern_info: utils::PatternInfo::default(),
        };

        let shape = vec![100, 10];
        let recommendation = utils::recommend_format(&analysis, &shape);

        assert!(recommendation.format.is_some());
        assert!(recommendation.confidence > 0.0);
    }
}