Skip to main content

tensorlogic_infer/
sparse.rs

1//! Sparse tensor support for TensorLogic.
2//!
3//! This module provides comprehensive sparse tensor representations and operations:
4//! - **CSR** (Compressed Sparse Row) format for efficient row operations
5//! - **CSC** (Compressed Sparse Column) format for efficient column operations
6//! - **COO** (Coordinate) format for flexible construction
7//! - **Sparse-dense hybrid operations**
8//! - **Automatic sparsity detection and conversion**
9//! - **Sparse matrix multiplication and linear algebra**
10//!
11//! ## Example
12//!
13//! ```rust,ignore
//! use tensorlogic_infer::{SparseFormat, SparseTensor};
//!
//! // Create a sparse matrix in COO format
//! let mut builder = SparseTensor::builder(vec![100, 100], SparseFormat::COO);
//! builder.add_entry(vec![5, 10], 3.14)?;
//! builder.add_entry(vec![20, 30], 2.71)?;
//! let sparse = builder.build()?;
//!
//! // Convert to CSR for efficient operations
//! let csr = sparse.to_csr()?;
//!
//! // Sparse-dense multiplication (matrix-vector product on the CSR payload)
//! let dense = vec![1.0; 100];
//! if let SparseTensor::CSR(matrix) = &csr {
//!     let result = matrix.multiply_dense(&dense)?;
//! }
28//!
29//! // Detect sparsity
30//! let sparsity = sparse.sparsity_ratio();
31//! println!("Sparsity: {:.2}%", sparsity * 100.0);
32//! ```
33
34use serde::{Deserialize, Serialize};
35use thiserror::Error;
36
/// Sparse tensor errors.
///
/// Returned by the construction (`add_entry`), validation (`validate`) and
/// conversion (`to_coo`, builder `build`) routines of this module.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum SparseError {
    /// A conversion between two sparse formats could not be performed; the
    /// two fields are rendered as `{0} -> {1}` in the message.
    #[error("Invalid sparse format conversion: {0} -> {1}")]
    InvalidConversion(String, String),

    /// A size did not match what the operation required, e.g. the vector
    /// length in `SparseCSR::multiply_dense` or the index rank passed to
    /// `SparseTensorBuilder::add_entry`.
    #[error("Shape mismatch: expected {expected:?}, got {actual:?}")]
    ShapeMismatch {
        expected: Vec<usize>,
        actual: Vec<usize>,
    },

    /// An index lies outside the shape of the tensor it addresses.
    #[error("Index out of bounds: {index:?} for shape {shape:?}")]
    IndexOutOfBounds {
        index: Vec<usize>,
        shape: Vec<usize>,
    },

    /// The internal arrays of a sparse tensor are structurally inconsistent
    /// (reported by `SparseCSR::validate`).
    #[error("Invalid sparse tensor: {0}")]
    Invalid(String),

    /// The requested operation is not implemented, e.g. building a tensor
    /// whose shape is not 2-D.
    #[error("Unsupported operation: {0}")]
    UnsupportedOperation(String),

    /// The tensor has no entries. NOTE(review): not constructed by the code
    /// in this section — presumably reserved for operations that need data.
    #[error("Empty sparse tensor")]
    Empty,
}
64
/// Sparse tensor storage format.
///
/// Selects the layout produced by `SparseTensorBuilder::build` and is the
/// value reported by `SparseTensor::format`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum SparseFormat {
    /// Compressed Sparse Row (CSR) - efficient for row-wise operations
    CSR,
    /// Compressed Sparse Column (CSC) - efficient for column-wise operations
    CSC,
    /// Coordinate (COO) - flexible for construction (entries appended in any order)
    COO,
}
75
76impl SparseFormat {
77    /// Get the format name.
78    pub fn name(&self) -> &'static str {
79        match self {
80            SparseFormat::CSR => "CSR",
81            SparseFormat::CSC => "CSC",
82            SparseFormat::COO => "COO",
83        }
84    }
85
86    /// Check if this format is compressed.
87    pub fn is_compressed(&self) -> bool {
88        matches!(self, SparseFormat::CSR | SparseFormat::CSC)
89    }
90}
91
/// Sparse matrix in CSR (Compressed Sparse Row) format.
///
/// The stored entries of row `r` occupy positions `row_ptr[r]..row_ptr[r + 1]`
/// of the parallel `col_indices` / `values` arrays.
///
/// Storage: O(nnz) where nnz is the number of non-zero elements
/// Row access: O(1) to locate a row, plus O(entries in that row) to read it
/// Column access: O(nnz)
///
/// Structural invariants checked by `SparseCSR::validate`: `row_ptr` has
/// `rows + 1` entries, is non-decreasing and ends at `nnz`, and every column
/// index is `< cols`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCSR {
    /// Shape of the matrix (rows, cols)
    pub shape: (usize, usize),
    /// Row pointers (length = rows + 1)
    pub row_ptr: Vec<usize>,
    /// Column indices for each non-zero entry
    pub col_indices: Vec<usize>,
    /// Values for each non-zero entry (parallel to `col_indices`)
    pub values: Vec<f64>,
}
108
109impl SparseCSR {
110    /// Create a new empty CSR matrix.
111    pub fn new(rows: usize, cols: usize) -> Self {
112        Self {
113            shape: (rows, cols),
114            row_ptr: vec![0; rows + 1],
115            col_indices: Vec::new(),
116            values: Vec::new(),
117        }
118    }
119
120    /// Get the number of non-zero elements.
121    pub fn nnz(&self) -> usize {
122        self.values.len()
123    }
124
125    /// Get sparsity ratio (fraction of zero elements).
126    pub fn sparsity_ratio(&self) -> f64 {
127        let total = self.shape.0 * self.shape.1;
128        1.0 - (self.nnz() as f64 / total as f64)
129    }
130
131    /// Get a row slice.
132    pub fn row(&self, row_idx: usize) -> Result<Vec<(usize, f64)>, SparseError> {
133        if row_idx >= self.shape.0 {
134            return Err(SparseError::IndexOutOfBounds {
135                index: vec![row_idx],
136                shape: vec![self.shape.0],
137            });
138        }
139
140        let start = self.row_ptr[row_idx];
141        let end = self.row_ptr[row_idx + 1];
142
143        Ok((start..end)
144            .map(|i| (self.col_indices[i], self.values[i]))
145            .collect())
146    }
147
148    /// Multiply with a dense vector (matrix-vector multiplication).
149    pub fn multiply_dense(&self, vec: &[f64]) -> Result<Vec<f64>, SparseError> {
150        if vec.len() != self.shape.1 {
151            return Err(SparseError::ShapeMismatch {
152                expected: vec![self.shape.1],
153                actual: vec![vec.len()],
154            });
155        }
156
157        let mut result = vec![0.0; self.shape.0];
158
159        for row_idx in 0..self.shape.0 {
160            let start = self.row_ptr[row_idx];
161            let end = self.row_ptr[row_idx + 1];
162
163            let mut sum = 0.0;
164            for i in start..end {
165                sum += self.values[i] * vec[self.col_indices[i]];
166            }
167            result[row_idx] = sum;
168        }
169
170        Ok(result)
171    }
172
173    /// Transpose to CSC format.
174    pub fn transpose(&self) -> SparseCSC {
175        let mut csc = SparseCSC::new(self.shape.1, self.shape.0);
176        csc.col_ptr = vec![0; self.shape.1 + 1];
177
178        // Count entries per column
179        let mut counts = vec![0; self.shape.1];
180        for &col in &self.col_indices {
181            counts[col] += 1;
182        }
183
184        // Build column pointers
185        let mut sum = 0;
186        for i in 0..self.shape.1 {
187            csc.col_ptr[i] = sum;
188            sum += counts[i];
189        }
190        csc.col_ptr[self.shape.1] = sum;
191
192        // Fill in entries
193        csc.row_indices = vec![0; self.nnz()];
194        csc.values = vec![0.0; self.nnz()];
195        let mut positions = csc.col_ptr[..self.shape.1].to_vec();
196
197        for row in 0..self.shape.0 {
198            let start = self.row_ptr[row];
199            let end = self.row_ptr[row + 1];
200
201            for i in start..end {
202                let col = self.col_indices[i];
203                let pos = positions[col];
204                csc.row_indices[pos] = row;
205                csc.values[pos] = self.values[i];
206                positions[col] += 1;
207            }
208        }
209
210        csc
211    }
212
213    /// Get memory usage in bytes.
214    pub fn memory_bytes(&self) -> usize {
215        self.row_ptr.len() * std::mem::size_of::<usize>()
216            + self.col_indices.len() * std::mem::size_of::<usize>()
217            + self.values.len() * std::mem::size_of::<f64>()
218    }
219
220    /// Validate the CSR structure.
221    pub fn validate(&self) -> Result<(), SparseError> {
222        // Check row pointers
223        if self.row_ptr.len() != self.shape.0 + 1 {
224            return Err(SparseError::Invalid(format!(
225                "Invalid row_ptr length: expected {}, got {}",
226                self.shape.0 + 1,
227                self.row_ptr.len()
228            )));
229        }
230
231        // Check monotonicity
232        for i in 0..self.shape.0 {
233            if self.row_ptr[i] > self.row_ptr[i + 1] {
234                return Err(SparseError::Invalid(format!(
235                    "Non-monotonic row_ptr at index {}",
236                    i
237                )));
238            }
239        }
240
241        // Check bounds
242        if self.row_ptr[self.shape.0] != self.nnz() {
243            return Err(SparseError::Invalid(format!(
244                "Last row_ptr {} doesn't match nnz {}",
245                self.row_ptr[self.shape.0],
246                self.nnz()
247            )));
248        }
249
250        // Check column indices
251        for &col in &self.col_indices {
252            if col >= self.shape.1 {
253                return Err(SparseError::IndexOutOfBounds {
254                    index: vec![0, col],
255                    shape: vec![self.shape.0, self.shape.1],
256                });
257            }
258        }
259
260        Ok(())
261    }
262}
263
/// Sparse matrix in CSC (Compressed Sparse Column) format.
///
/// The stored entries of column `c` occupy positions `col_ptr[c]..col_ptr[c + 1]`
/// of the parallel `row_indices` / `values` arrays.
///
/// Storage: O(nnz) where nnz is the number of non-zero elements
/// Row access: O(nnz)
/// Column access: O(1) to locate a column, plus O(entries in that column) to read it
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCSC {
    /// Shape of the matrix (rows, cols)
    pub shape: (usize, usize),
    /// Column pointers (length = cols + 1)
    pub col_ptr: Vec<usize>,
    /// Row indices for each non-zero entry
    pub row_indices: Vec<usize>,
    /// Values for each non-zero entry (parallel to `row_indices`)
    pub values: Vec<f64>,
}
280
281impl SparseCSC {
282    /// Create a new empty CSC matrix.
283    pub fn new(rows: usize, cols: usize) -> Self {
284        Self {
285            shape: (rows, cols),
286            col_ptr: vec![0; cols + 1],
287            row_indices: Vec::new(),
288            values: Vec::new(),
289        }
290    }
291
292    /// Get the number of non-zero elements.
293    pub fn nnz(&self) -> usize {
294        self.values.len()
295    }
296
297    /// Get sparsity ratio.
298    pub fn sparsity_ratio(&self) -> f64 {
299        let total = self.shape.0 * self.shape.1;
300        1.0 - (self.nnz() as f64 / total as f64)
301    }
302
303    /// Get a column slice.
304    pub fn column(&self, col_idx: usize) -> Result<Vec<(usize, f64)>, SparseError> {
305        if col_idx >= self.shape.1 {
306            return Err(SparseError::IndexOutOfBounds {
307                index: vec![col_idx],
308                shape: vec![self.shape.1],
309            });
310        }
311
312        let start = self.col_ptr[col_idx];
313        let end = self.col_ptr[col_idx + 1];
314
315        Ok((start..end)
316            .map(|i| (self.row_indices[i], self.values[i]))
317            .collect())
318    }
319
320    /// Transpose to CSR format.
321    pub fn transpose(&self) -> SparseCSR {
322        let mut csr = SparseCSR::new(self.shape.1, self.shape.0);
323        csr.row_ptr = self.col_ptr.clone();
324        csr.col_indices = self.row_indices.clone();
325        csr.values = self.values.clone();
326        csr
327    }
328}
329
/// Sparse matrix in COO (Coordinate) format.
///
/// Entry `i` is `(row_indices[i], col_indices[i], values[i])`. Entries are kept
/// in insertion order and duplicate coordinates are not merged.
///
/// Storage: O(nnz) where nnz is the number of non-zero elements
/// Random access: O(nnz)
/// Best for: Construction and modification
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SparseCOO {
    /// Shape of the matrix (rows, cols)
    pub shape: (usize, usize),
    /// Row indices
    pub row_indices: Vec<usize>,
    /// Column indices
    pub col_indices: Vec<usize>,
    /// Values
    pub values: Vec<f64>,
}
346
347impl SparseCOO {
348    /// Create a new empty COO matrix.
349    pub fn new(rows: usize, cols: usize) -> Self {
350        Self {
351            shape: (rows, cols),
352            row_indices: Vec::new(),
353            col_indices: Vec::new(),
354            values: Vec::new(),
355        }
356    }
357
358    /// Add a non-zero entry.
359    pub fn add_entry(&mut self, row: usize, col: usize, value: f64) -> Result<(), SparseError> {
360        if row >= self.shape.0 || col >= self.shape.1 {
361            return Err(SparseError::IndexOutOfBounds {
362                index: vec![row, col],
363                shape: vec![self.shape.0, self.shape.1],
364            });
365        }
366
367        self.row_indices.push(row);
368        self.col_indices.push(col);
369        self.values.push(value);
370
371        Ok(())
372    }
373
374    /// Get the number of non-zero elements.
375    pub fn nnz(&self) -> usize {
376        self.values.len()
377    }
378
379    /// Get sparsity ratio.
380    pub fn sparsity_ratio(&self) -> f64 {
381        let total = self.shape.0 * self.shape.1;
382        1.0 - (self.nnz() as f64 / total as f64)
383    }
384
385    /// Convert to CSR format.
386    pub fn to_csr(&self) -> SparseCSR {
387        let mut csr = SparseCSR::new(self.shape.0, self.shape.1);
388
389        // Create a sorted list of (row, col, value) tuples
390        let mut entries: Vec<_> = (0..self.nnz())
391            .map(|i| (self.row_indices[i], self.col_indices[i], self.values[i]))
392            .collect();
393        entries.sort_by_key(|(r, c, _)| (*r, *c));
394
395        // Build CSR structure
396        csr.row_ptr = vec![0; self.shape.0 + 1];
397        csr.col_indices = Vec::with_capacity(entries.len());
398        csr.values = Vec::with_capacity(entries.len());
399
400        let mut current_row = 0;
401        for (row, col, val) in entries {
402            while current_row < row {
403                current_row += 1;
404                csr.row_ptr[current_row] = csr.col_indices.len();
405            }
406            csr.col_indices.push(col);
407            csr.values.push(val);
408        }
409
410        // Fill remaining row pointers
411        for i in current_row + 1..=self.shape.0 {
412            csr.row_ptr[i] = csr.col_indices.len();
413        }
414
415        csr
416    }
417
418    /// Convert to CSC format.
419    pub fn to_csc(&self) -> SparseCSC {
420        let mut csc = SparseCSC::new(self.shape.0, self.shape.1);
421
422        // Create a sorted list of (col, row, value) tuples
423        let mut entries: Vec<_> = (0..self.nnz())
424            .map(|i| (self.col_indices[i], self.row_indices[i], self.values[i]))
425            .collect();
426        entries.sort_by_key(|(c, r, _)| (*c, *r));
427
428        // Build CSC structure
429        csc.col_ptr = vec![0; self.shape.1 + 1];
430        csc.row_indices = Vec::with_capacity(entries.len());
431        csc.values = Vec::with_capacity(entries.len());
432
433        let mut current_col = 0;
434        for (col, row, val) in entries {
435            while current_col < col {
436                current_col += 1;
437                csc.col_ptr[current_col] = csc.row_indices.len();
438            }
439            csc.row_indices.push(row);
440            csc.values.push(val);
441        }
442
443        // Fill remaining column pointers
444        for i in current_col + 1..=self.shape.1 {
445            csc.col_ptr[i] = csc.row_indices.len();
446        }
447
448        csc
449    }
450}
451
/// Sparse tensor representation.
///
/// Wraps a 2-D matrix in one of the supported storage formats. Use
/// `SparseTensor::format` to see which one, and the `to_csr` / `to_csc` /
/// `to_coo` methods to convert between them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SparseTensor {
    /// 2D matrix in CSR format
    CSR(SparseCSR),
    /// 2D matrix in CSC format
    CSC(SparseCSC),
    /// 2D matrix in COO format
    COO(SparseCOO),
}
462
463impl SparseTensor {
464    /// Create a sparse tensor builder.
465    pub fn builder(shape: Vec<usize>, format: SparseFormat) -> SparseTensorBuilder {
466        SparseTensorBuilder::new(shape, format)
467    }
468
469    /// Get the sparse format.
470    pub fn format(&self) -> SparseFormat {
471        match self {
472            SparseTensor::CSR(_) => SparseFormat::CSR,
473            SparseTensor::CSC(_) => SparseFormat::CSC,
474            SparseTensor::COO(_) => SparseFormat::COO,
475        }
476    }
477
478    /// Get the shape.
479    pub fn shape(&self) -> Vec<usize> {
480        match self {
481            SparseTensor::CSR(m) => vec![m.shape.0, m.shape.1],
482            SparseTensor::CSC(m) => vec![m.shape.0, m.shape.1],
483            SparseTensor::COO(m) => vec![m.shape.0, m.shape.1],
484        }
485    }
486
487    /// Get the number of non-zero elements.
488    pub fn nnz(&self) -> usize {
489        match self {
490            SparseTensor::CSR(m) => m.nnz(),
491            SparseTensor::CSC(m) => m.nnz(),
492            SparseTensor::COO(m) => m.nnz(),
493        }
494    }
495
496    /// Get sparsity ratio.
497    pub fn sparsity_ratio(&self) -> f64 {
498        match self {
499            SparseTensor::CSR(m) => m.sparsity_ratio(),
500            SparseTensor::CSC(m) => m.sparsity_ratio(),
501            SparseTensor::COO(m) => m.sparsity_ratio(),
502        }
503    }
504
505    /// Convert to CSR format.
506    pub fn to_csr(&self) -> Result<SparseTensor, SparseError> {
507        match self {
508            SparseTensor::CSR(_) => Ok(self.clone()),
509            SparseTensor::CSC(m) => Ok(SparseTensor::CSR(m.transpose())),
510            SparseTensor::COO(m) => Ok(SparseTensor::CSR(m.to_csr())),
511        }
512    }
513
514    /// Convert to CSC format.
515    pub fn to_csc(&self) -> Result<SparseTensor, SparseError> {
516        match self {
517            SparseTensor::CSR(m) => Ok(SparseTensor::CSC(m.transpose())),
518            SparseTensor::CSC(_) => Ok(self.clone()),
519            SparseTensor::COO(m) => Ok(SparseTensor::CSC(m.to_csc())),
520        }
521    }
522
523    /// Convert to COO format.
524    pub fn to_coo(&self) -> Result<SparseTensor, SparseError> {
525        match self {
526            SparseTensor::COO(_) => Ok(self.clone()),
527            SparseTensor::CSR(m) => {
528                let mut coo = SparseCOO::new(m.shape.0, m.shape.1);
529                for row in 0..m.shape.0 {
530                    let start = m.row_ptr[row];
531                    let end = m.row_ptr[row + 1];
532                    for i in start..end {
533                        coo.add_entry(row, m.col_indices[i], m.values[i])?;
534                    }
535                }
536                Ok(SparseTensor::COO(coo))
537            }
538            SparseTensor::CSC(m) => {
539                let mut coo = SparseCOO::new(m.shape.0, m.shape.1);
540                for col in 0..m.shape.1 {
541                    let start = m.col_ptr[col];
542                    let end = m.col_ptr[col + 1];
543                    for i in start..end {
544                        coo.add_entry(m.row_indices[i], col, m.values[i])?;
545                    }
546                }
547                Ok(SparseTensor::COO(coo))
548            }
549        }
550    }
551
552    /// Get memory usage in bytes.
553    pub fn memory_bytes(&self) -> usize {
554        match self {
555            SparseTensor::CSR(m) => m.memory_bytes(),
556            SparseTensor::CSC(m) => {
557                m.col_ptr.len() * std::mem::size_of::<usize>()
558                    + m.row_indices.len() * std::mem::size_of::<usize>()
559                    + m.values.len() * std::mem::size_of::<f64>()
560            }
561            SparseTensor::COO(m) => {
562                (m.row_indices.len() + m.col_indices.len()) * std::mem::size_of::<usize>()
563                    + m.values.len() * std::mem::size_of::<f64>()
564            }
565        }
566    }
567}
568
569/// Builder for sparse tensors.
570pub struct SparseTensorBuilder {
571    shape: Vec<usize>,
572    format: SparseFormat,
573    entries: Vec<(Vec<usize>, f64)>,
574}
575
576impl SparseTensorBuilder {
577    /// Create a new sparse tensor builder.
578    pub fn new(shape: Vec<usize>, format: SparseFormat) -> Self {
579        Self {
580            shape,
581            format,
582            entries: Vec::new(),
583        }
584    }
585
586    /// Add a non-zero entry.
587    pub fn add_entry(&mut self, indices: Vec<usize>, value: f64) -> Result<(), SparseError> {
588        if indices.len() != self.shape.len() {
589            return Err(SparseError::ShapeMismatch {
590                expected: vec![self.shape.len()],
591                actual: vec![indices.len()],
592            });
593        }
594
595        for (i, &idx) in indices.iter().enumerate() {
596            if idx >= self.shape[i] {
597                return Err(SparseError::IndexOutOfBounds {
598                    index: indices.clone(),
599                    shape: self.shape.clone(),
600                });
601            }
602        }
603
604        self.entries.push((indices, value));
605        Ok(())
606    }
607
608    /// Build the sparse tensor.
609    pub fn build(self) -> Result<SparseTensor, SparseError> {
610        // Currently only support 2D tensors
611        if self.shape.len() != 2 {
612            return Err(SparseError::UnsupportedOperation(format!(
613                "Only 2D sparse tensors are supported, got shape {:?}",
614                self.shape
615            )));
616        }
617
618        let rows = self.shape[0];
619        let cols = self.shape[1];
620
621        // Build COO first
622        let mut coo = SparseCOO::new(rows, cols);
623        for (indices, value) in self.entries {
624            coo.add_entry(indices[0], indices[1], value)?;
625        }
626
627        // Convert to requested format
628        match self.format {
629            SparseFormat::COO => Ok(SparseTensor::COO(coo)),
630            SparseFormat::CSR => Ok(SparseTensor::CSR(coo.to_csr())),
631            SparseFormat::CSC => Ok(SparseTensor::CSC(coo.to_csc())),
632        }
633    }
634}
635
/// Detect sparsity in a dense tensor.
///
/// An element counts as zero when `|x| < threshold`. The comparison is strict,
/// so with `threshold == 0.0` nothing, not even an exact `0.0`, is counted. Pass
/// a small positive tolerance such as `1e-10` to treat exact zeros as zero.
/// `NaN` elements are never counted as zero.
///
/// Returns `(zero_count, zero_fraction)`. Empty input yields `(0, 0.0)` rather
/// than a `NaN` fraction.
pub fn detect_sparsity(data: &[f64], threshold: f64) -> (usize, f64) {
    let zeros = data.iter().filter(|&&x| x.abs() < threshold).count();
    let sparsity = if data.is_empty() {
        0.0
    } else {
        zeros as f64 / data.len() as f64
    };
    (zeros, sparsity)
}
643
644/// Convert dense tensor to sparse if beneficial.
645pub fn to_sparse_if_beneficial(
646    data: &[f64],
647    shape: Vec<usize>,
648    threshold: f64,
649    min_sparsity: f64,
650) -> Result<Option<SparseTensor>, SparseError> {
651    let (_, sparsity) = detect_sparsity(data, threshold);
652
653    if sparsity < min_sparsity {
654        return Ok(None);
655    }
656
657    // Build sparse tensor
658    let mut builder = SparseTensor::builder(shape.clone(), SparseFormat::CSR);
659
660    if shape.len() == 2 {
661        let cols = shape[1];
662        for (i, &val) in data.iter().enumerate() {
663            if val.abs() >= threshold {
664                let row = i / cols;
665                let col = i % cols;
666                builder.add_entry(vec![row, col], val)?;
667            }
668        }
669    }
670
671    Ok(Some(builder.build()?))
672}
673
#[cfg(test)]
mod tests {
    use super::*;

    /// Assemble a `rows x cols` COO matrix from `(row, col, value)` triples.
    fn coo_from(rows: usize, cols: usize, triples: &[(usize, usize, f64)]) -> SparseCOO {
        let mut coo = SparseCOO::new(rows, cols);
        for &(r, c, v) in triples {
            coo.add_entry(r, c, v).unwrap();
        }
        coo
    }

    /// Assemble a 2-D `SparseTensor` in `format` through the builder API.
    fn tensor_from(
        shape: [usize; 2],
        format: SparseFormat,
        triples: &[(usize, usize, f64)],
    ) -> SparseTensor {
        let mut builder = SparseTensor::builder(shape.to_vec(), format);
        for &(r, c, v) in triples {
            builder.add_entry(vec![r, c], v).unwrap();
        }
        builder.build().unwrap()
    }

    #[test]
    fn test_sparse_format() {
        assert_eq!(SparseFormat::CSR.name(), "CSR");
        assert!(SparseFormat::CSR.is_compressed());
        assert!(!SparseFormat::COO.is_compressed());
    }

    #[test]
    fn test_sparse_coo_creation() {
        let mut coo = SparseCOO::new(3, 3);
        assert_eq!(coo.shape, (3, 3));
        assert_eq!(coo.nnz(), 0);

        coo.add_entry(0, 1, 5.0).unwrap();
        coo.add_entry(1, 2, 3.0).unwrap();
        assert_eq!(coo.nnz(), 2);
    }

    #[test]
    fn test_sparse_coo_to_csr() {
        let csr = coo_from(3, 3, &[(0, 0, 1.0), (0, 2, 2.0), (2, 1, 3.0)]).to_csr();
        assert_eq!(csr.shape, (3, 3));
        assert_eq!(csr.nnz(), 3);
        assert!(csr.validate().is_ok());
    }

    #[test]
    fn test_sparse_csr_multiply_dense() {
        let csr = coo_from(2, 3, &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)]).to_csr();
        let product = csr.multiply_dense(&[1.0, 2.0, 3.0]).unwrap();

        assert_eq!(product.len(), 2);
        // Row 0: 1*1 + 2*3 = 7; row 1: 3*2 = 6.
        for (got, want) in product.iter().zip(&[7.0, 6.0]) {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_sparse_csr_row_access() {
        let csr = coo_from(3, 3, &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)]).to_csr();
        assert_eq!(csr.row(0).unwrap(), vec![(0, 1.0), (2, 2.0)]);
        assert_eq!(csr.row(1).unwrap(), vec![(1, 3.0)]);
    }

    #[test]
    fn test_sparse_csr_transpose() {
        let csc = coo_from(2, 3, &[(0, 0, 1.0), (0, 2, 2.0), (1, 1, 3.0)])
            .to_csr()
            .transpose();
        assert_eq!(csc.shape, (3, 2));
        assert_eq!(csc.nnz(), 3);
    }

    #[test]
    fn test_sparsity_ratio() {
        let coo = coo_from(10, 10, &[(0, 0, 1.0), (5, 5, 2.0)]);
        // 2 of 100 entries stored => 98% sparse.
        assert!((coo.sparsity_ratio() - 0.98).abs() < 0.01);
    }

    #[test]
    fn test_sparse_tensor_builder() {
        let sparse = tensor_from([3, 3], SparseFormat::CSR, &[(0, 0, 1.0), (1, 2, 2.0)]);
        assert_eq!(sparse.format(), SparseFormat::CSR);
        assert_eq!(sparse.nnz(), 2);
    }

    #[test]
    fn test_sparse_tensor_conversion() {
        let coo = tensor_from([3, 3], SparseFormat::COO, &[(0, 0, 1.0), (1, 2, 2.0)]);
        let csr = coo.to_csr().unwrap();
        let csc = csr.to_csc().unwrap();

        for tensor in [&coo, &csr, &csc].iter() {
            assert_eq!(tensor.nnz(), 2);
        }
    }

    #[test]
    fn test_detect_sparsity() {
        let data = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];
        let (zeros, sparsity) = detect_sparsity(&data, 1e-10);

        assert_eq!(zeros, 6);
        assert!((sparsity - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_to_sparse_if_beneficial() {
        let data = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0];
        let sparse = to_sparse_if_beneficial(&data, vec![2, 3], 1e-10, 0.5)
            .unwrap()
            .expect("data is two-thirds zeros, so a sparse tensor should be produced");

        assert_eq!(sparse.nnz(), 2);
        assert!(sparse.sparsity_ratio() > 0.5);
    }

    #[test]
    fn test_sparse_csr_validation() {
        let well_formed = SparseCSR {
            shape: (3, 3),
            row_ptr: vec![0, 2, 3, 3],
            col_indices: vec![0, 2, 1],
            values: vec![1.0, 2.0, 3.0],
        };
        assert!(well_formed.validate().is_ok());
    }

    #[test]
    fn test_sparse_memory_usage() {
        let sparse = tensor_from([100, 100], SparseFormat::CSR, &[(0, 0, 1.0), (50, 50, 2.0)]);
        let dense_bytes = 100 * 100 * std::mem::size_of::<f64>();

        // Two entries in a 100x100 matrix need well under a tenth of dense storage.
        assert!(sparse.memory_bytes() < dense_bytes / 10);
    }
}