torsh_core/
sparse.rs

1// Sparse tensor metadata and storage formats for ToRSh Core
2// Supports COO (Coordinate), CSR (Compressed Sparse Row), and CSC (Compressed Sparse Column) formats
3
4use crate::dtype::DType;
5use crate::error::TorshError;
6use crate::shape::Shape;
7
8use std::fmt;
9use std::sync::Arc;
10
/// Storage layouts a sparse tensor can use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SparseFormat {
    /// Coordinate list: one explicit index tuple per stored value.
    /// Memory: O(nnz) for indices + values
    /// Good for: construction, element access
    COO,

    /// Compressed Sparse Row.
    /// Memory: O(nnz) values + O(nnz) column indices + O(rows+1) row pointers
    /// Good for: matrix-vector multiplication, row access
    CSR,

    /// Compressed Sparse Column.
    /// Memory: O(nnz) values + O(nnz) row indices + O(cols+1) column pointers
    /// Good for: matrix-vector multiplication (transposed), column access
    CSC,

    /// Block Sparse Row, for structured sparsity.
    /// Good for: GPU acceleration, structured pruning
    BSR,

    /// Diagonal storage, for diagonal and band matrices.
    /// Good for: diagonal matrices, finite difference operators
    DIA,

    /// ELLPack, for GPU-optimized sparse operations.
    /// Good for: GPU kernels with regular sparsity patterns
    ELL,
}

impl SparseFormat {
    /// Short upper-case label of the format (for example `"CSR"`).
    pub fn name(self) -> &'static str {
        match self {
            SparseFormat::COO => "COO",
            SparseFormat::CSR => "CSR",
            SparseFormat::CSC => "CSC",
            SparseFormat::BSR => "BSR",
            SparseFormat::DIA => "DIA",
            SparseFormat::ELL => "ELL",
        }
    }

    /// Whether the layout is classified as offering efficient row access
    /// (CSR and BSR).
    pub fn supports_row_access(self) -> bool {
        match self {
            Self::CSR | Self::BSR => true,
            Self::COO | Self::CSC | Self::DIA | Self::ELL => false,
        }
    }

    /// Whether the layout is classified as offering efficient column access
    /// (CSC only).
    pub fn supports_column_access(self) -> bool {
        self == Self::CSC
    }

    /// Whether the layout is classified as suitable for GPU operations
    /// (CSR, ELL and BSR).
    pub fn is_gpu_friendly(self) -> bool {
        match self {
            Self::CSR | Self::ELL | Self::BSR => true,
            Self::COO | Self::CSC | Self::DIA => false,
        }
    }
}
70
71impl fmt::Display for SparseFormat {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        write!(f, "{}", self.name())
74    }
75}
76
77/// Sparse tensor metadata containing format and structural information
78#[derive(Debug, Clone)]
79pub struct SparseMetadata {
80    /// Storage format
81    format: SparseFormat,
82
83    /// Number of non-zero elements
84    nnz: usize,
85
86    /// Sparsity ratio (0.0 = dense, 1.0 = all zeros)
87    sparsity: f32,
88
89    /// Whether indices are sorted
90    indices_sorted: bool,
91
92    /// Whether duplicates have been summed
93    duplicates_summed: bool,
94
95    /// Block size for BSR format
96    block_size: Option<(usize, usize)>,
97
98    /// Number of diagonals for DIA format
99    num_diagonals: Option<usize>,
100
101    /// ELLPack row width
102    ell_width: Option<usize>,
103
104    /// Compression statistics
105    compression_stats: CompressionStats,
106}
107
108/// Statistics about sparse tensor compression efficiency
109#[derive(Debug, Clone)]
110pub struct CompressionStats {
111    /// Theoretical dense size in bytes
112    dense_size_bytes: usize,
113
114    /// Actual sparse storage size in bytes
115    sparse_size_bytes: usize,
116
117    /// Compression ratio (dense_size / sparse_size)
118    compression_ratio: f32,
119
120    /// Memory overhead from indices storage
121    #[allow(dead_code)] // Index overhead tracking - future implementation
122    index_overhead_bytes: usize,
123}
124
125impl SparseMetadata {
126    /// Create new sparse metadata
127    pub fn new(
128        format: SparseFormat,
129        nnz: usize,
130        total_elements: usize,
131        dense_size_bytes: usize,
132        sparse_size_bytes: usize,
133    ) -> Self {
134        let sparsity = 1.0 - (nnz as f32 / total_elements as f32);
135        let compression_ratio = dense_size_bytes as f32 / sparse_size_bytes as f32;
136
137        Self {
138            format,
139            nnz,
140            sparsity,
141            indices_sorted: false,
142            duplicates_summed: false,
143            block_size: None,
144            num_diagonals: None,
145            ell_width: None,
146            compression_stats: CompressionStats {
147                dense_size_bytes,
148                sparse_size_bytes,
149                compression_ratio,
150                index_overhead_bytes: sparse_size_bytes - (nnz * 4), // Rough estimate
151            },
152        }
153    }
154
155    /// Get storage format
156    pub fn format(&self) -> SparseFormat {
157        self.format
158    }
159
160    /// Get number of non-zero elements
161    pub fn nnz(&self) -> usize {
162        self.nnz
163    }
164
165    /// Get sparsity ratio (fraction of zero elements)
166    pub fn sparsity(&self) -> f32 {
167        self.sparsity
168    }
169
170    /// Get density ratio (fraction of non-zero elements)
171    pub fn density(&self) -> f32 {
172        1.0 - self.sparsity
173    }
174
175    /// Check if indices are sorted
176    pub fn indices_sorted(&self) -> bool {
177        self.indices_sorted
178    }
179
180    /// Mark indices as sorted
181    pub fn set_indices_sorted(&mut self, sorted: bool) {
182        self.indices_sorted = sorted;
183    }
184
185    /// Check if duplicates have been summed
186    pub fn duplicates_summed(&self) -> bool {
187        self.duplicates_summed
188    }
189
190    /// Mark duplicates as summed
191    pub fn set_duplicates_summed(&mut self, summed: bool) {
192        self.duplicates_summed = summed;
193    }
194
195    /// Get block size for BSR format
196    pub fn block_size(&self) -> Option<(usize, usize)> {
197        self.block_size
198    }
199
200    /// Set block size for BSR format
201    pub fn set_block_size(&mut self, size: (usize, usize)) {
202        self.block_size = Some(size);
203    }
204
205    /// Get compression statistics
206    pub fn compression_stats(&self) -> &CompressionStats {
207        &self.compression_stats
208    }
209
210    /// Check if sparse representation is beneficial
211    pub fn is_beneficial(&self) -> bool {
212        self.compression_stats.compression_ratio > 1.2 // At least 20% savings
213    }
214
215    /// Estimate memory savings compared to dense representation
216    pub fn memory_savings_bytes(&self) -> i64 {
217        self.compression_stats.dense_size_bytes as i64
218            - self.compression_stats.sparse_size_bytes as i64
219    }
220
221    /// Get format-specific information as string
222    pub fn format_info(&self) -> String {
223        match self.format {
224            SparseFormat::BSR => {
225                if let Some((bm, bn)) = self.block_size {
226                    format!("BSR({}x{})", bm, bn)
227                } else {
228                    "BSR".to_string()
229                }
230            }
231            SparseFormat::DIA => {
232                if let Some(ndiag) = self.num_diagonals {
233                    format!("DIA({})", ndiag)
234                } else {
235                    "DIA".to_string()
236                }
237            }
238            SparseFormat::ELL => {
239                if let Some(width) = self.ell_width {
240                    format!("ELL({})", width)
241                } else {
242                    "ELL".to_string()
243                }
244            }
245            _ => self.format.name().to_string(),
246        }
247    }
248}
249
250impl fmt::Display for SparseMetadata {
251    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252        write!(
253            f,
254            "SparseMetadata({}, nnz={}, sparsity={:.2}%, compression={:.1}x)",
255            self.format_info(),
256            self.nnz,
257            self.sparsity * 100.0,
258            self.compression_stats.compression_ratio
259        )
260    }
261}
262
/// COO (Coordinate) format sparse tensor indices.
///
/// Entry `i` is located at `(rows[i], cols[i], extra_dims[0][i], ...)`.
/// The fields are public; code that mutates them directly must keep every
/// index array the same length, otherwise the methods below may panic on an
/// out-of-bounds access.
#[derive(Debug, Clone)]
pub struct CooIndices {
    /// Row indices (length = nnz)
    pub rows: Vec<usize>,

    /// Column indices (length = nnz)
    pub cols: Vec<usize>,

    /// Higher dimension indices for tensors with ndim > 2
    pub extra_dims: Vec<Vec<usize>>,
}

impl CooIndices {
    /// Create new COO indices for 2D tensor.
    ///
    /// # Panics
    ///
    /// Panics if `rows` and `cols` differ in length.
    pub fn new_2d(rows: Vec<usize>, cols: Vec<usize>) -> Self {
        assert_eq!(
            rows.len(),
            cols.len(),
            "Row and column indices must have same length"
        );

        Self {
            rows,
            cols,
            extra_dims: Vec::new(),
        }
    }

    /// Create new COO indices for N-D tensor.
    ///
    /// `indices[d]` holds the coordinate along dimension `d` for every stored
    /// entry. The first two vectors become `rows` and `cols`; the rest are
    /// kept in `extra_dims`.
    ///
    /// # Panics
    ///
    /// Panics if the per-dimension vectors differ in length, or if fewer than
    /// two dimensions are supplied.
    pub fn new_nd(indices: Vec<Vec<usize>>) -> Self {
        let nnz = indices.first().map_or(0, Vec::len);

        // Validate all dimensions have same length
        for (i, dim_indices) in indices.iter().enumerate() {
            assert_eq!(
                dim_indices.len(),
                nnz,
                "Dimension {} indices length mismatch: expected {}, got {}",
                i,
                nnz,
                dim_indices.len()
            );
        }

        // `indices` is owned, so the vectors are moved out rather than cloned.
        let mut dims = indices.into_iter();
        let (Some(rows), Some(cols)) = (dims.next(), dims.next()) else {
            panic!("N-D tensor must have at least 2 dimensions");
        };

        Self {
            rows,
            cols,
            extra_dims: dims.collect(),
        }
    }

    /// Get number of non-zero elements
    pub fn nnz(&self) -> usize {
        self.rows.len()
    }

    /// Get number of dimensions
    pub fn ndim(&self) -> usize {
        2 + self.extra_dims.len()
    }

    /// Check if indices are sorted in lexicographic order.
    ///
    /// All dimensions take part in the comparison (row, then column, then
    /// each extra dimension in order), matching the order produced by
    /// [`CooIndices::sort`]. Duplicate entries count as sorted.
    pub fn is_sorted(&self) -> bool {
        (1..self.nnz()).all(|i| self.cmp_entries(i - 1, i) != std::cmp::Ordering::Greater)
    }

    /// Sort indices in lexicographic order, returning permutation.
    ///
    /// `perm[i]` is the previous position of the entry now stored at `i`, so
    /// a parallel value buffer can be reordered the same way. The sort is
    /// stable: duplicates keep their relative order.
    pub fn sort(&mut self) -> Vec<usize> {
        let mut perm: Vec<usize> = (0..self.nnz()).collect();
        perm.sort_by(|&a, &b| self.cmp_entries(a, b));

        let gather = |source: &[usize]| -> Vec<usize> { perm.iter().map(|&p| source[p]).collect() };

        let rows = gather(self.rows.as_slice());
        let cols = gather(self.cols.as_slice());
        let extra_dims: Vec<Vec<usize>> = self
            .extra_dims
            .iter()
            .map(|dim| gather(dim.as_slice()))
            .collect();

        self.rows = rows;
        self.cols = cols;
        self.extra_dims = extra_dims;

        perm
    }

    /// Lexicographic comparison of the entries stored at positions `a` and `b`.
    fn cmp_entries(&self, a: usize, b: usize) -> std::cmp::Ordering {
        self.rows[a]
            .cmp(&self.rows[b])
            .then_with(|| self.cols[a].cmp(&self.cols[b]))
            .then_with(|| {
                self.extra_dims
                    .iter()
                    .map(|dim| dim[a].cmp(&dim[b]))
                    .find(|ordering| ordering.is_ne())
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }
}
390
391/// CSR (Compressed Sparse Row) format indices
392#[derive(Debug, Clone)]
393pub struct CsrIndices {
394    /// Row pointers (length = nrows + 1)
395    pub row_ptrs: Vec<usize>,
396
397    /// Column indices (length = nnz)
398    pub col_indices: Vec<usize>,
399}
400
401impl CsrIndices {
402    /// Create new CSR indices
403    pub fn new(row_ptrs: Vec<usize>, col_indices: Vec<usize>) -> Self {
404        // Validate structure
405        let nnz = col_indices.len();
406        let _nrows = row_ptrs.len().saturating_sub(1);
407
408        assert_eq!(
409            *row_ptrs.last().unwrap_or(&0),
410            nnz,
411            "Last row pointer must equal nnz"
412        );
413
414        // Validate row pointers are non-decreasing
415        for i in 1..row_ptrs.len() {
416            assert!(
417                row_ptrs[i] >= row_ptrs[i - 1],
418                "Row pointers must be non-decreasing"
419            );
420        }
421
422        Self {
423            row_ptrs,
424            col_indices,
425        }
426    }
427
428    /// Convert from COO format
429    pub fn from_coo(coo: &CooIndices, nrows: usize) -> Self {
430        let _nnz = coo.nnz();
431        let mut row_ptrs = vec![0; nrows + 1];
432
433        // Count elements per row
434        for &row in &coo.rows {
435            if row < nrows {
436                row_ptrs[row + 1] += 1;
437            }
438        }
439
440        // Convert counts to cumulative sums
441        for i in 1..=nrows {
442            row_ptrs[i] += row_ptrs[i - 1];
443        }
444
445        // Create column indices array (assume COO is already sorted)
446        let col_indices = coo.cols.clone();
447
448        Self::new(row_ptrs, col_indices)
449    }
450
451    /// Get number of rows
452    pub fn nrows(&self) -> usize {
453        self.row_ptrs.len().saturating_sub(1)
454    }
455
456    /// Get number of non-zero elements
457    pub fn nnz(&self) -> usize {
458        self.col_indices.len()
459    }
460
461    /// Get range of column indices for a row
462    pub fn row_range(&self, row: usize) -> Option<std::ops::Range<usize>> {
463        if row >= self.nrows() {
464            return None;
465        }
466        Some(self.row_ptrs[row]..self.row_ptrs[row + 1])
467    }
468}
469
470/// Sparse tensor storage trait
471pub trait SparseStorage: Send + Sync + std::fmt::Debug {
472    /// Get sparse metadata
473    fn metadata(&self) -> &SparseMetadata;
474
475    /// Get element count
476    fn nnz(&self) -> usize {
477        self.metadata().nnz()
478    }
479
480    /// Get storage format
481    fn format(&self) -> SparseFormat {
482        self.metadata().format()
483    }
484
485    /// Check if representation is beneficial vs dense
486    fn is_beneficial(&self) -> bool {
487        self.metadata().is_beneficial()
488    }
489
490    /// Convert to COO format if possible
491    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;
492
493    /// Convert to CSR format if possible
494    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError>;
495
496    /// Get memory usage in bytes
497    fn memory_usage(&self) -> usize;
498}
499
500/// COO format sparse storage
501#[derive(Debug)]
502pub struct CooStorage {
503    metadata: SparseMetadata,
504    indices: CooIndices,
505    values: Vec<u8>, // Raw bytes for type-erased storage
506    dtype: DType,
507    shape: Shape,
508}
509
510impl CooStorage {
511    /// Create new COO storage
512    pub fn new(
513        indices: CooIndices,
514        values: Vec<u8>,
515        dtype: DType,
516        shape: Shape,
517    ) -> Result<Self, TorshError> {
518        let nnz = indices.nnz();
519        let expected_value_size = nnz * dtype.size();
520
521        if values.len() != expected_value_size {
522            return Err(TorshError::InvalidArgument(format!(
523                "Value buffer size mismatch: expected {}, got {}",
524                expected_value_size,
525                values.len()
526            )));
527        }
528
529        let total_elements: usize = shape.dims().iter().product();
530        let dense_size = total_elements * dtype.size();
531        let sparse_size = values.len() + indices.rows.len() * 8 + indices.cols.len() * 8; // Rough estimate
532
533        let metadata = SparseMetadata::new(
534            SparseFormat::COO,
535            nnz,
536            total_elements,
537            dense_size,
538            sparse_size,
539        );
540
541        Ok(Self {
542            metadata,
543            indices,
544            values,
545            dtype,
546            shape,
547        })
548    }
549
550    /// Get indices reference
551    pub fn indices(&self) -> &CooIndices {
552        &self.indices
553    }
554
555    /// Get mutable indices reference
556    pub fn indices_mut(&mut self) -> &mut CooIndices {
557        &mut self.indices
558    }
559
560    /// Get values as raw bytes
561    pub fn values_bytes(&self) -> &[u8] {
562        &self.values
563    }
564
565    /// Get data type
566    pub fn dtype(&self) -> DType {
567        self.dtype
568    }
569
570    /// Get shape
571    pub fn shape(&self) -> &Shape {
572        &self.shape
573    }
574}
575
576impl SparseStorage for CooStorage {
577    fn metadata(&self) -> &SparseMetadata {
578        &self.metadata
579    }
580
581    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
582        // Already COO format, return clone
583        Ok(Arc::new(CooStorage {
584            metadata: self.metadata.clone(),
585            indices: self.indices.clone(),
586            values: self.values.clone(),
587            dtype: self.dtype,
588            shape: self.shape.clone(),
589        }))
590    }
591
592    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
593        if self.shape.ndim() != 2 {
594            return Err(TorshError::InvalidArgument(
595                "CSR format only supports 2D tensors".to_string(),
596            ));
597        }
598
599        let nrows = self.shape.dims()[0];
600        let csr_indices = CsrIndices::from_coo(&self.indices, nrows);
601
602        Ok(Arc::new(CsrStorage {
603            metadata: {
604                let mut meta = self.metadata.clone();
605                meta.format = SparseFormat::CSR;
606                meta
607            },
608            indices: csr_indices,
609            values: self.values.clone(),
610            dtype: self.dtype,
611            shape: self.shape.clone(),
612        }))
613    }
614
615    fn memory_usage(&self) -> usize {
616        self.values.len()
617            + self.indices.rows.len() * std::mem::size_of::<usize>()
618            + self.indices.cols.len() * std::mem::size_of::<usize>()
619            + self
620                .indices
621                .extra_dims
622                .iter()
623                .map(|dim| dim.len() * std::mem::size_of::<usize>())
624                .sum::<usize>()
625    }
626}
627
628/// CSR format sparse storage
629#[derive(Debug)]
630pub struct CsrStorage {
631    metadata: SparseMetadata,
632    indices: CsrIndices,
633    values: Vec<u8>,
634    dtype: DType,
635    shape: Shape,
636}
637
638impl CsrStorage {
639    /// Create new CSR storage
640    pub fn new(
641        indices: CsrIndices,
642        values: Vec<u8>,
643        dtype: DType,
644        shape: Shape,
645    ) -> Result<Self, TorshError> {
646        if shape.ndim() != 2 {
647            return Err(TorshError::InvalidArgument(
648                "CSR format only supports 2D tensors".to_string(),
649            ));
650        }
651
652        let nnz = indices.nnz();
653        let expected_value_size = nnz * dtype.size();
654
655        if values.len() != expected_value_size {
656            return Err(TorshError::InvalidArgument(format!(
657                "Value buffer size mismatch: expected {}, got {}",
658                expected_value_size,
659                values.len()
660            )));
661        }
662
663        let total_elements: usize = shape.dims().iter().product();
664        let dense_size = total_elements * dtype.size();
665        let sparse_size = values.len() + indices.row_ptrs.len() * 8 + indices.col_indices.len() * 8;
666
667        let metadata = SparseMetadata::new(
668            SparseFormat::CSR,
669            nnz,
670            total_elements,
671            dense_size,
672            sparse_size,
673        );
674
675        Ok(Self {
676            metadata,
677            indices,
678            values,
679            dtype,
680            shape,
681        })
682    }
683
684    /// Get indices reference
685    pub fn indices(&self) -> &CsrIndices {
686        &self.indices
687    }
688}
689
690impl SparseStorage for CsrStorage {
691    fn metadata(&self) -> &SparseMetadata {
692        &self.metadata
693    }
694
695    fn to_coo(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
696        // Convert CSR back to COO
697        let mut rows = Vec::with_capacity(self.nnz());
698        let mut cols = Vec::with_capacity(self.nnz());
699
700        for row in 0..self.indices.nrows() {
701            let range = self.indices.row_range(row).unwrap();
702            for col_idx in range {
703                rows.push(row);
704                cols.push(self.indices.col_indices[col_idx]);
705            }
706        }
707
708        let coo_indices = CooIndices::new_2d(rows, cols);
709
710        Ok(Arc::new(CooStorage {
711            metadata: {
712                let mut meta = self.metadata.clone();
713                meta.format = SparseFormat::COO;
714                meta
715            },
716            indices: coo_indices,
717            values: self.values.clone(),
718            dtype: self.dtype,
719            shape: self.shape.clone(),
720        }))
721    }
722
723    fn to_csr(&self) -> Result<Arc<dyn SparseStorage>, TorshError> {
724        // Already CSR format
725        Ok(Arc::new(CsrStorage {
726            metadata: self.metadata.clone(),
727            indices: self.indices.clone(),
728            values: self.values.clone(),
729            dtype: self.dtype,
730            shape: self.shape.clone(),
731        }))
732    }
733
734    fn memory_usage(&self) -> usize {
735        self.values.len()
736            + self.indices.row_ptrs.len() * std::mem::size_of::<usize>()
737            + self.indices.col_indices.len() * std::mem::size_of::<usize>()
738    }
739}
740
741/// Utilities for sparse tensor operations
742pub mod utils {
743    use super::*;
744
745    /// Analyze sparsity patterns in dense data
746    pub fn analyze_sparsity(data: &[f32], shape: &[usize]) -> SparseAnalysis {
747        let total_elements = data.len();
748        let mut nnz = 0;
749        let mut pattern_info = PatternInfo::default();
750
751        // Count non-zeros and analyze patterns
752        for (idx, &value) in data.iter().enumerate() {
753            if value != 0.0 {
754                nnz += 1;
755                pattern_info.update(idx, shape);
756            }
757        }
758
759        let sparsity = 1.0 - (nnz as f32 / total_elements as f32);
760
761        SparseAnalysis {
762            sparsity,
763            nnz,
764            total_elements,
765            pattern_info,
766        }
767    }
768
769    /// Recommend optimal sparse format based on sparsity analysis
770    pub fn recommend_format(analysis: &SparseAnalysis, shape: &[usize]) -> FormatRecommendation {
771        let sparsity = analysis.sparsity;
772        let nnz = analysis.nnz;
773
774        // Simple heuristics for format recommendation
775        if sparsity < 0.5 {
776            return FormatRecommendation {
777                format: None, // Dense is better
778                reason: "Low sparsity, dense representation more efficient".to_string(),
779                confidence: 0.9,
780            };
781        }
782
783        if shape.len() == 2 {
784            // 2D matrix
785            let (nrows, ncols) = (shape[0], shape[1]);
786
787            if analysis.pattern_info.has_structured_rows {
788                return FormatRecommendation {
789                    format: Some(SparseFormat::CSR),
790                    reason: "Good row locality, CSR optimal for row-wise operations".to_string(),
791                    confidence: 0.8,
792                };
793            }
794
795            if analysis.pattern_info.has_structured_cols {
796                return FormatRecommendation {
797                    format: Some(SparseFormat::CSC),
798                    reason: "Good column locality, CSC optimal for column-wise operations"
799                        .to_string(),
800                    confidence: 0.8,
801                };
802            }
803
804            if nnz < (nrows + ncols) * 10 {
805                return FormatRecommendation {
806                    format: Some(SparseFormat::COO),
807                    reason: "Very sparse matrix, COO has lowest overhead".to_string(),
808                    confidence: 0.7,
809                };
810            }
811
812            return FormatRecommendation {
813                format: Some(SparseFormat::CSR),
814                reason: "General 2D sparse matrix, CSR is default choice".to_string(),
815                confidence: 0.6,
816            };
817        }
818
819        // N-D tensor
820        FormatRecommendation {
821            format: Some(SparseFormat::COO),
822            reason: "Multi-dimensional tensor, COO supports arbitrary dimensions".to_string(),
823            confidence: 0.8,
824        }
825    }
826
827    /// Convert dense data to optimal sparse format
828    pub fn densify_to_sparse<T>(
829        data: &[T],
830        shape: &Shape,
831        dtype: DType,
832        threshold: Option<f64>,
833    ) -> Result<Arc<dyn SparseStorage>, TorshError>
834    where
835        T: Clone + PartialEq + Into<f64> + Default,
836    {
837        let threshold = threshold.unwrap_or(1e-12);
838        let zero = T::default();
839
840        // Find non-zero elements
841        let mut indices = Vec::new();
842        let mut values = Vec::new();
843
844        for (linear_idx, value) in data.iter().enumerate() {
845            let abs_val = value.clone().into().abs();
846            if abs_val > threshold && *value != zero {
847                // Convert linear index to multi-dimensional indices
848                let multi_idx = linear_to_multidim(linear_idx, shape.dims());
849                indices.push(multi_idx);
850                values.push(value.clone());
851            }
852        }
853
854        if indices.is_empty() {
855            return Err(TorshError::InvalidArgument(
856                "No non-zero elements found".to_string(),
857            ));
858        }
859
860        // Convert to bytes (simplified - real implementation would handle type properly)
861        let value_bytes: Vec<u8> = values
862            .iter()
863            .flat_map(|v| {
864                let val_f64 = v.clone().into();
865                val_f64.to_ne_bytes()
866            })
867            .collect();
868
869        // Create COO indices
870        let dims = shape.dims();
871        match dims.len() {
872            1 => {
873                let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
874                let cols = vec![0; rows.len()]; // Dummy column for 1D
875                let coo_indices = CooIndices::new_2d(rows, cols);
876                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
877                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
878            }
879            2 => {
880                let rows: Vec<usize> = indices.iter().map(|idx| idx[0]).collect();
881                let cols: Vec<usize> = indices.iter().map(|idx| idx[1]).collect();
882                let coo_indices = CooIndices::new_2d(rows, cols);
883                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
884                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
885            }
886            _ => {
887                let transposed_indices: Vec<Vec<usize>> = (0..dims.len())
888                    .map(|dim| indices.iter().map(|idx| idx[dim]).collect())
889                    .collect();
890                let coo_indices = CooIndices::new_nd(transposed_indices);
891                CooStorage::new(coo_indices, value_bytes, dtype, shape.clone())
892                    .map(|storage| Arc::new(storage) as Arc<dyn SparseStorage>)
893            }
894        }
895    }
896
897    // Helper function to convert linear index to multi-dimensional
898    fn linear_to_multidim(linear_idx: usize, shape: &[usize]) -> Vec<usize> {
899        let mut result = Vec::with_capacity(shape.len());
900        let mut remaining = linear_idx;
901
902        for &dim_size in shape.iter().rev() {
903            result.push(remaining % dim_size);
904            remaining /= dim_size;
905        }
906
907        result.reverse();
908        result
909    }
910
    /// Analysis results for sparse data, as produced by `analyze_sparsity`.
    #[derive(Debug, Clone)]
    pub struct SparseAnalysis {
        /// Fraction of elements that are zero (`1 - nnz / total_elements`).
        pub sparsity: f32,
        /// Number of elements that differ from `0.0`.
        pub nnz: usize,
        /// Total number of elements that were inspected.
        pub total_elements: usize,
        /// Structural hints gathered while scanning the non-zero elements.
        pub pattern_info: PatternInfo,
    }
919
920    /// Information about sparsity patterns
921    #[derive(Debug, Clone, Default)]
922    pub struct PatternInfo {
923        pub has_structured_rows: bool,
924        pub has_structured_cols: bool,
925        pub has_diagonal_structure: bool,
926        pub has_block_structure: bool,
927        pub block_size: Option<(usize, usize)>,
928    }
929
930    impl PatternInfo {
931        fn update(&mut self, idx: usize, shape: &[usize]) {
932            // Simplified pattern detection logic
933            // Real implementation would be more sophisticated
934            if shape.len() == 2 {
935                let (_nrows, ncols) = (shape[0], shape[1]);
936                let row = idx / ncols;
937                let col = idx % ncols;
938
939                // Check for diagonal elements
940                if row == col {
941                    self.has_diagonal_structure = true;
942                }
943
944                // Simple heuristics for structured patterns
945                if row.is_multiple_of(4) && col.is_multiple_of(4) {
946                    self.has_block_structure = true;
947                    self.block_size = Some((4, 4));
948                }
949            }
950        }
951    }
952
    /// Format recommendation result, as returned by `recommend_format`.
    #[derive(Debug, Clone)]
    pub struct FormatRecommendation {
        /// Suggested sparse format; `None` means dense storage is preferable.
        pub format: Option<SparseFormat>,
        /// Human-readable explanation of the choice.
        pub reason: String,
        /// Confidence attached to the recommendation.
        pub confidence: f32, // 0.0 - 1.0
    }
960}
961
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape::Shape;

    /// Metadata getters report what the constructor was given.
    #[test]
    fn test_sparse_metadata_creation() {
        let metadata = SparseMetadata::new(
            SparseFormat::COO,
            1000,  // nnz
            10000, // total elements
            40000, // dense size (10k * 4 bytes)
            8000,  // sparse size
        );

        assert_eq!(metadata.format(), SparseFormat::COO);
        assert_eq!(metadata.nnz(), 1000);
        // Computed in f32, so compare with a tolerance instead of exactly.
        assert!((metadata.sparsity() - 0.9).abs() < 1e-6); // 90% sparse
        assert!(metadata.is_beneficial()); // 5x compression
    }

    /// `new_2d` stores the index vectors unchanged.
    #[test]
    fn test_coo_indices_creation() {
        let rows = vec![0, 1, 2, 1];
        let cols = vec![1, 0, 2, 2];

        let indices = CooIndices::new_2d(rows.clone(), cols.clone());

        assert_eq!(indices.nnz(), 4);
        assert_eq!(indices.ndim(), 2);
        assert_eq!(indices.rows, rows);
        assert_eq!(indices.cols, cols);
    }

    /// Sorting orders entries by row, then column.
    #[test]
    fn test_coo_indices_sorting() {
        let mut indices = CooIndices::new_2d(
            vec![2, 1, 0, 1], // rows
            vec![0, 2, 1, 0], // cols
        );

        assert!(!indices.is_sorted());

        let _perm = indices.sort();

        // After sorting: [(0,1), (1,0), (1,2), (2,0)]
        assert_eq!(indices.rows, vec![0, 1, 1, 2]);
        assert_eq!(indices.cols, vec![1, 0, 2, 0]);
        assert!(indices.is_sorted());
    }

    /// Row-ordered COO input converts to the expected row pointers.
    #[test]
    fn test_csr_from_coo() {
        let coo_indices = CooIndices::new_2d(
            vec![0, 0, 1, 2, 2], // rows
            vec![1, 2, 0, 1, 2], // cols
        );

        let csr_indices = CsrIndices::from_coo(&coo_indices, 3);

        assert_eq!(csr_indices.nrows(), 3);
        assert_eq!(csr_indices.nnz(), 5);
        assert_eq!(csr_indices.row_ptrs, vec![0, 2, 3, 5]);
        assert_eq!(csr_indices.col_indices, vec![1, 2, 0, 1, 2]);
    }

    /// A consistent index/value pair builds a COO storage.
    #[test]
    fn test_coo_storage_creation() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);

        let storage = CooStorage::new(indices, values, DType::F32, shape).unwrap();

        assert_eq!(storage.nnz(), 2);
        assert_eq!(storage.format(), SparseFormat::COO);
        assert_eq!(storage.dtype(), DType::F32);
    }

    /// COO -> CSR -> COO keeps the reported format in step.
    #[test]
    fn test_format_conversion() {
        let indices = CooIndices::new_2d(vec![0, 1], vec![1, 0]);
        let values = [1.0_f32.to_ne_bytes(), 2.0_f32.to_ne_bytes()].concat();
        let shape = Shape::new(vec![2, 2]);

        let coo_storage = CooStorage::new(indices, values, DType::F32, shape).unwrap();

        // Convert COO -> CSR
        let csr_storage = coo_storage.to_csr().unwrap();
        assert_eq!(csr_storage.format(), SparseFormat::CSR);
        assert_eq!(csr_storage.nnz(), 2);

        // Convert CSR -> COO
        let coo_again = csr_storage.to_coo().unwrap();
        assert_eq!(coo_again.format(), SparseFormat::COO);
        assert_eq!(coo_again.nnz(), 2);
    }

    /// Three non-zeros out of nine elements give a sparsity of 2/3.
    #[test]
    fn test_sparsity_analysis() {
        let data = vec![0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
        let shape = vec![3, 3];

        let analysis = utils::analyze_sparsity(&data, &shape);

        assert_eq!(analysis.nnz, 3);
        assert_eq!(analysis.total_elements, 9);
        assert!((analysis.sparsity - 2.0 / 3.0).abs() < 1e-6);
    }

    /// A highly sparse 2D input always gets some sparse format.
    #[test]
    fn test_format_recommendation() {
        // High sparsity case
        let analysis = utils::SparseAnalysis {
            sparsity: 0.9,
            nnz: 100,
            total_elements: 1000,
            pattern_info: utils::PatternInfo::default(),
        };

        let shape = vec![100, 10];
        let recommendation = utils::recommend_format(&analysis, &shape);

        assert!(recommendation.format.is_some());
        assert!(recommendation.confidence > 0.0);
    }
}
1084        assert!(recommendation.confidence > 0.0);
1085    }
1086}