// torsh_tensor/sparse.rs
1//! Sparse tensor implementation with multiple storage formats
2//!
3//! This module provides efficient sparse tensor operations for tensors with many zero elements.
4//! Supports three storage formats:
5//! - COO (Coordinate): General purpose, easy to construct and modify
6//! - CSR (Compressed Sparse Row): Efficient for row-wise operations and matrix-vector multiplication
7//! - CSC (Compressed Sparse Column): Efficient for column-wise operations
8//!
9//! Each format has different performance characteristics:
10//! - COO: Best for construction, random access, and format conversion
11//! - CSR: Best for row slicing, matrix-vector multiplication (A*x)
12//! - CSC: Best for column slicing, matrix-vector multiplication (A^T*x)
13
14use crate::{core_ops::Tensor, TensorElement};
15use scirs2_core::parallel_ops::*;
16use std::collections::HashMap;
17use torsh_core::{
18    device::DeviceType,
19    error::{Result, TorshError},
20}; // SciRS2 Parallel Operations for sparse tensor processing
21
22/// Sparse tensor in COO (Coordinate) format
23///
24/// COO format stores sparse tensors efficiently by only keeping track of non-zero elements
25/// and their coordinates. This is particularly useful for tensors with high sparsity ratios.
26///
27/// # Format
28/// - `indices`: N x D matrix where N is the number of non-zero elements and D is the number of dimensions
29/// - `values`: N-length vector containing the non-zero values
30/// - `shape`: The shape of the full dense tensor
31///
32/// # Example
33/// ```rust
34/// use torsh_tensor::sparse::SparseTensor;
35///
36/// // Create a 3x3 sparse matrix with values at (0,0)=1.0, (1,2)=2.0, (2,1)=3.0
37/// let indices = vec![vec![0, 0], vec![1, 2], vec![2, 1]];
38/// let values = vec![1.0, 2.0, 3.0];
39/// let shape = vec![3, 3];
40///
41/// let sparse = SparseTensor::from_coo(indices, values, shape).expect("sparse tensor creation should succeed");
42/// ```
#[derive(Debug, Clone)]
pub struct SparseTensor<T: TensorElement> {
    /// Coordinates of non-zero elements; entry `i` of the outer Vec is the
    /// full D-dimensional coordinate of `values[i]`
    indices: Vec<Vec<usize>>,
    /// Non-zero values, parallel to `indices`
    values: Vec<T>,
    /// Shape of the full dense tensor
    shape: Vec<usize>,
    /// Device where the tensor resides
    device: DeviceType,
    /// Cached number of non-zero elements (equals `values.len()`)
    nnz: usize,
}
56
57impl<T: TensorElement> SparseTensor<T> {
58    /// Create a new sparse tensor from COO format
59    ///
60    /// # Arguments
61    /// * `indices` - Coordinates of non-zero elements (each inner vector is one coordinate)
62    /// * `values` - Non-zero values corresponding to the indices
63    /// * `shape` - Shape of the full dense tensor
64    ///
65    /// # Returns
66    /// A new sparse tensor in COO format
67    ///
68    /// # Errors
69    /// Returns error if indices and values have mismatched lengths or if coordinates are out of bounds
70    pub fn from_coo(indices: Vec<Vec<usize>>, values: Vec<T>, shape: Vec<usize>) -> Result<Self> {
71        if indices.len() != values.len() {
72            return Err(TorshError::InvalidArgument(format!(
73                "Indices length ({}) must match values length ({})",
74                indices.len(),
75                values.len()
76            )));
77        }
78
79        let ndim = shape.len();
80        for (i, coord) in indices.iter().enumerate() {
81            if coord.len() != ndim {
82                return Err(TorshError::InvalidArgument(format!(
83                    "Index {} has {} dimensions, expected {}",
84                    i,
85                    coord.len(),
86                    ndim
87                )));
88            }
89
90            for (dim, &idx) in coord.iter().enumerate() {
91                if idx >= shape[dim] {
92                    return Err(TorshError::InvalidArgument(format!(
93                        "Index {} at dimension {} is out of bounds ({})",
94                        idx, dim, shape[dim]
95                    )));
96                }
97            }
98        }
99
100        Ok(Self {
101            nnz: indices.len(),
102            indices,
103            values,
104            shape,
105            device: DeviceType::Cpu,
106        })
107    }
108
109    /// Create a sparse tensor from a dense tensor by extracting non-zero elements
110    ///
111    /// # Arguments
112    /// * `dense` - The dense tensor to convert
113    /// * `tolerance` - Values with absolute value below this threshold are considered zero
114    ///
115    /// # Returns
116    /// A new sparse tensor containing only the non-zero elements
117    pub fn from_dense(dense: &Tensor<T>, tolerance: T) -> Result<Self>
118    where
119        T: Copy + PartialOrd + num_traits::Zero + num_traits::Signed,
120    {
121        let data = dense.data()?;
122        let shape = dense.shape().dims().to_vec();
123
124        let mut indices = Vec::new();
125        let mut values = Vec::new();
126
127        // Iterate through all elements and collect values above tolerance
128        for flat_idx in 0..data.len() {
129            let value = data[flat_idx];
130
131            // Check if absolute value exceeds tolerance threshold
132            let abs_value = value.abs();
133
134            // Only include values that exceed the tolerance threshold
135            if abs_value > tolerance {
136                // Convert flat index to multi-dimensional coordinates
137                let coords = Self::flat_to_coords(flat_idx, &shape);
138                indices.push(coords);
139                values.push(value);
140            }
141        }
142
143        Self::from_coo(indices, values, shape)
144    }
145
146    /// Convert this sparse tensor to a dense tensor
147    ///
148    /// # Returns
149    /// A dense tensor with zeros filled in for missing elements
150    pub fn to_dense(&self) -> Result<Tensor<T>>
151    where
152        T: Copy + num_traits::Zero,
153    {
154        let total_elements: usize = self.shape.iter().product();
155        let mut data = vec![<T as num_traits::Zero>::zero(); total_elements];
156
157        // Fill in the non-zero values
158        for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
159            let flat_idx = Self::coords_to_flat(coords, &self.shape);
160            data[flat_idx] = value;
161        }
162
163        Tensor::from_data(data, self.shape.clone(), self.device)
164    }
165
    /// Get the number of non-zero elements
    pub fn nnz(&self) -> usize {
        self.nnz
    }

    /// Get the shape of the sparse tensor
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Get the device where the tensor resides
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Get the indices (coordinates) of non-zero elements
    pub fn indices(&self) -> &[Vec<usize>] {
        &self.indices
    }

    /// Get the non-zero values
    pub fn values(&self) -> &[T] {
        &self.values
    }

    /// Calculate the sparsity ratio (fraction of zero elements)
    ///
    /// Returns 1.0 when every element is zero, 0.0 when all are stored.
    /// NOTE(review): divides by the product of the dimensions — a shape with
    /// a zero-sized dimension produces a NaN; confirm callers never build one.
    pub fn sparsity(&self) -> f64 {
        let total_elements: usize = self.shape.iter().product();
        1.0 - (self.nnz as f64 / total_elements as f64)
    }

    /// Get the memory footprint of the sparse representation in bytes
    ///
    /// Counts coordinate storage, value storage, the shape vector, and the
    /// struct header; spare Vec capacity beyond `len` is not included.
    pub fn memory_usage(&self) -> usize {
        let indices_size = self.indices.len() * self.shape.len() * std::mem::size_of::<usize>();
        let values_size = self.values.len() * std::mem::size_of::<T>();
        let shape_size = self.shape.len() * std::mem::size_of::<usize>();

        indices_size + values_size + shape_size + std::mem::size_of::<Self>()
    }

    /// Compare memory usage with equivalent dense representation
    ///
    /// Positive means the sparse form is smaller; the result can be negative
    /// when the tensor is too dense for the sparse layout to pay off.
    pub fn memory_efficiency(&self) -> f64 {
        let total_elements: usize = self.shape.iter().product();
        let dense_size = total_elements * std::mem::size_of::<T>();
        let sparse_size = self.memory_usage();

        1.0 - (sparse_size as f64 / dense_size as f64)
    }
214
215    /// Element-wise addition with another sparse tensor
216    ///
217    /// # Arguments
218    /// * `other` - The other sparse tensor to add
219    ///
220    /// # Returns
221    /// A new sparse tensor containing the sum
222    pub fn add(&self, other: &Self) -> Result<Self>
223    where
224        T: Copy + std::ops::Add<Output = T> + num_traits::Zero + PartialEq,
225    {
226        if self.shape != other.shape {
227            return Err(TorshError::InvalidArgument(format!(
228                "Shape mismatch: {:?} vs {:?}",
229                self.shape, other.shape
230            )));
231        }
232
233        // Use HashMap to efficiently merge non-zero elements
234        let mut result_map: HashMap<Vec<usize>, T> = HashMap::new();
235
236        // Add elements from first tensor
237        for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
238            result_map.insert(coords.clone(), value);
239        }
240
241        // Add elements from second tensor
242        for (coords, &value) in other.indices.iter().zip(other.values.iter()) {
243            match result_map.get_mut(coords) {
244                Some(existing_value) => {
245                    *existing_value = *existing_value + value;
246                }
247                None => {
248                    result_map.insert(coords.clone(), value);
249                }
250            }
251        }
252
253        // Filter out zeros and collect results
254        let zero = <T as num_traits::Zero>::zero();
255        let mut indices = Vec::new();
256        let mut values = Vec::new();
257
258        for (coords, value) in result_map {
259            if value != zero {
260                indices.push(coords);
261                values.push(value);
262            }
263        }
264
265        Self::from_coo(indices, values, self.shape.clone())
266    }
267
268    /// Element-wise multiplication with another sparse tensor
269    ///
270    /// For sparse tensors, multiplication only produces non-zero results where both tensors
271    /// have non-zero elements at the same location.
272    pub fn mul(&self, other: &Self) -> Result<Self>
273    where
274        T: Copy + std::ops::Mul<Output = T> + num_traits::Zero + PartialEq,
275    {
276        if self.shape != other.shape {
277            return Err(TorshError::InvalidArgument(format!(
278                "Shape mismatch: {:?} vs {:?}",
279                self.shape, other.shape
280            )));
281        }
282
283        // Create HashMap for efficient lookup
284        let other_map: HashMap<Vec<usize>, T> = other
285            .indices
286            .iter()
287            .zip(other.values.iter())
288            .map(|(coords, &value)| (coords.clone(), value))
289            .collect();
290
291        let mut indices = Vec::new();
292        let mut values = Vec::new();
293        let zero = <T as num_traits::Zero>::zero();
294
295        // Only multiply where both tensors have non-zero elements
296        for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
297            if let Some(&other_value) = other_map.get(coords) {
298                let result = value * other_value;
299                if result != zero {
300                    indices.push(coords.clone());
301                    values.push(result);
302                }
303            }
304        }
305
306        Self::from_coo(indices, values, self.shape.clone())
307    }
308
309    /// Scalar multiplication
310    pub fn mul_scalar(&self, scalar: T) -> Result<Self>
311    where
312        T: Copy + std::ops::Mul<Output = T> + num_traits::Zero + PartialEq,
313    {
314        let zero = <T as num_traits::Zero>::zero();
315        if scalar == zero {
316            // Result is all zeros - return empty sparse tensor
317            return Self::from_coo(Vec::new(), Vec::new(), self.shape.clone());
318        }
319
320        let new_values: Vec<T> = self.values.iter().map(|&v| v * scalar).collect();
321
322        Self::from_coo(self.indices.clone(), new_values, self.shape.clone())
323    }
324
325    /// Matrix multiplication for 2D sparse tensors
326    ///
327    /// # Arguments
328    /// * `other` - The other 2D sparse tensor to multiply with
329    ///
330    /// # Returns
331    /// A new sparse tensor containing the matrix product
332    pub fn matmul(&self, other: &Self) -> Result<Self>
333    where
334        T: Copy
335            + std::ops::Add<Output = T>
336            + std::ops::Mul<Output = T>
337            + num_traits::Zero
338            + PartialEq,
339    {
340        if self.shape.len() != 2 || other.shape.len() != 2 {
341            return Err(TorshError::InvalidArgument(
342                "Matrix multiplication requires 2D tensors".to_string(),
343            ));
344        }
345
346        if self.shape[1] != other.shape[0] {
347            return Err(TorshError::InvalidArgument(format!(
348                "Incompatible shapes for matmul: {:?} x {:?}",
349                self.shape, other.shape
350            )));
351        }
352
353        let m = self.shape[0];
354        let n = other.shape[1];
355        let k = self.shape[1];
356
357        // Sparse matrix multiplication uses only non-zero entries
358        // k (inner dimension) is validated for compatibility but not directly iterated
359        let _ = (m, k, n); // Use dimensions for validation
360
361        // Create efficient lookup structures
362        let mut left_rows: HashMap<usize, Vec<(usize, T)>> = HashMap::new();
363        let mut right_cols: HashMap<usize, Vec<(usize, T)>> = HashMap::new();
364
365        // Organize left matrix by rows
366        for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
367            let row = coords[0];
368            let col = coords[1];
369            left_rows
370                .entry(row)
371                .or_insert_with(Vec::new)
372                .push((col, value));
373        }
374
375        // Organize right matrix by columns
376        for (coords, &value) in other.indices.iter().zip(other.values.iter()) {
377            let row = coords[0];
378            let col = coords[1];
379            right_cols
380                .entry(col)
381                .or_insert_with(Vec::new)
382                .push((row, value));
383        }
384
385        let mut result_map: HashMap<Vec<usize>, T> = HashMap::new();
386        let zero = <T as num_traits::Zero>::zero();
387
388        // Compute matrix multiplication
389        for (&row, left_row_data) in left_rows.iter() {
390            for (&col, right_col_data) in right_cols.iter() {
391                let mut sum = zero;
392
393                // Compute dot product of row and column
394                let mut left_iter = left_row_data.iter().peekable();
395                let mut right_iter = right_col_data.iter().peekable();
396
397                while let (Some(&(left_col, left_val)), Some(&(right_row, right_val))) =
398                    (left_iter.peek(), right_iter.peek())
399                {
400                    match left_col.cmp(&right_row) {
401                        std::cmp::Ordering::Equal => {
402                            sum = sum + (*left_val) * (*right_val);
403                            left_iter.next();
404                            right_iter.next();
405                        }
406                        std::cmp::Ordering::Less => {
407                            left_iter.next();
408                        }
409                        std::cmp::Ordering::Greater => {
410                            right_iter.next();
411                        }
412                    }
413                }
414
415                if sum != zero {
416                    result_map.insert(vec![row, col], sum);
417                }
418            }
419        }
420
421        // Convert result map to COO format
422        let mut indices = Vec::new();
423        let mut values = Vec::new();
424
425        for (coords, value) in result_map {
426            indices.push(coords);
427            values.push(value);
428        }
429
430        Self::from_coo(indices, values, vec![m, n])
431    }
432
433    /// Convert flat index to multi-dimensional coordinates
434    fn flat_to_coords(flat_idx: usize, shape: &[usize]) -> Vec<usize> {
435        let mut coords = vec![0; shape.len()];
436        let mut remaining = flat_idx;
437
438        for i in 0..shape.len() {
439            let stride: usize = shape[i + 1..].iter().product();
440            coords[i] = remaining / stride;
441            remaining %= stride;
442        }
443
444        coords
445    }
446
447    /// Convert multi-dimensional coordinates to flat index
448    fn coords_to_flat(coords: &[usize], shape: &[usize]) -> usize {
449        let mut flat_idx = 0;
450        let mut stride = 1;
451
452        for i in (0..coords.len()).rev() {
453            flat_idx += coords[i] * stride;
454            stride *= shape[i];
455        }
456
457        flat_idx
458    }
459
460    /// Transpose a 2D sparse tensor
461    pub fn transpose(&self) -> Result<Self>
462    where
463        T: Copy,
464    {
465        if self.shape.len() != 2 {
466            return Err(TorshError::InvalidArgument(
467                "Transpose is only supported for 2D tensors".to_string(),
468            ));
469        }
470
471        let new_shape = vec![self.shape[1], self.shape[0]];
472        let new_indices: Vec<Vec<usize>> = self
473            .indices
474            .iter()
475            .map(|coords| vec![coords[1], coords[0]])
476            .collect();
477
478        Self::from_coo(new_indices, self.values.clone(), new_shape)
479    }
480
481    /// Apply a function to all non-zero values
482    pub fn map<F>(&self, f: F) -> Result<Self>
483    where
484        F: Fn(T) -> T,
485        T: Copy + num_traits::Zero + PartialEq,
486    {
487        let new_values: Vec<T> = self.values.iter().map(|&v| f(v)).collect();
488
489        // Filter out any values that became zero
490        let zero = <T as num_traits::Zero>::zero();
491        let mut filtered_indices = Vec::new();
492        let mut filtered_values = Vec::new();
493
494        for (coords, &value) in self.indices.iter().zip(new_values.iter()) {
495            if value != zero {
496                filtered_indices.push(coords.clone());
497                filtered_values.push(value);
498            }
499        }
500
501        Self::from_coo(filtered_indices, filtered_values, self.shape.clone())
502    }
503
504    /// Check if the sparse tensor is structurally valid
505    pub fn is_valid(&self) -> bool {
506        // Check that indices and values have same length
507        if self.indices.len() != self.values.len() {
508            return false;
509        }
510
511        // Check that nnz matches actual length
512        if self.nnz != self.indices.len() {
513            return false;
514        }
515
516        // Check that all indices are within bounds
517        let ndim = self.shape.len();
518        for coords in &self.indices {
519            if coords.len() != ndim {
520                return false;
521            }
522
523            for (dim, &idx) in coords.iter().enumerate() {
524                if idx >= self.shape[dim] {
525                    return false;
526                }
527            }
528        }
529
530        true
531    }
532
533    /// Remove duplicate indices by summing their values
534    pub fn coalesce(&mut self) -> Result<()>
535    where
536        T: Copy + std::ops::AddAssign + num_traits::Zero + PartialEq,
537    {
538        if self.indices.is_empty() {
539            return Ok(());
540        }
541
542        let mut coord_map: HashMap<Vec<usize>, T> = HashMap::new();
543
544        // Sum values for duplicate coordinates
545        for (coords, &value) in self.indices.iter().zip(self.values.iter()) {
546            match coord_map.get_mut(coords) {
547                Some(existing_value) => {
548                    *existing_value += value;
549                }
550                None => {
551                    coord_map.insert(coords.clone(), value);
552                }
553            }
554        }
555
556        // Filter out zeros and rebuild indices/values
557        let zero = <T as num_traits::Zero>::zero();
558        let mut new_indices = Vec::new();
559        let mut new_values = Vec::new();
560
561        for (coords, value) in coord_map {
562            if value != zero {
563                new_indices.push(coords);
564                new_values.push(value);
565            }
566        }
567
568        self.indices = new_indices;
569        self.values = new_values;
570        self.nnz = self.indices.len();
571
572        Ok(())
573    }
574}
575
576/// Conversion utilities for sparse tensors
577impl<T: TensorElement> SparseTensor<T> {
578    /// Create a sparse identity matrix
579    pub fn eye(size: usize) -> Result<Self>
580    where
581        T: Copy + num_traits::One,
582    {
583        let mut indices = Vec::new();
584        let mut values = Vec::new();
585        let one = <T as num_traits::One>::one();
586
587        for i in 0..size {
588            indices.push(vec![i, i]);
589            values.push(one);
590        }
591
592        Self::from_coo(indices, values, vec![size, size])
593    }
594
595    /// Create a sparse tensor from triplets (row, col, value) for 2D case
596    pub fn from_triplets(
597        rows: Vec<usize>,
598        cols: Vec<usize>,
599        vals: Vec<T>,
600        shape: Vec<usize>,
601    ) -> Result<Self> {
602        if rows.len() != cols.len() || cols.len() != vals.len() {
603            return Err(TorshError::InvalidArgument(
604                "Rows, cols, and values must have the same length".to_string(),
605            ));
606        }
607
608        let indices: Vec<Vec<usize>> = rows
609            .into_iter()
610            .zip(cols.into_iter())
611            .map(|(r, c)| vec![r, c])
612            .collect();
613
614        Self::from_coo(indices, vals, shape)
615    }
616}
617
618/// Sparse tensor in CSR (Compressed Sparse Row) format
619///
620/// CSR format is optimized for row-wise operations and matrix-vector multiplication.
621/// It stores:
622/// - `row_ptr`: Array of size (num_rows + 1) indicating where each row starts
623/// - `col_indices`: Column indices for each non-zero value
624/// - `values`: Non-zero values in row-major order
625///
626/// # Example
627/// ```rust
628/// use torsh_tensor::sparse::SparseCSR;
629///
630/// // Create a 3x3 sparse matrix in CSR format
631/// // [[1.0, 0.0, 2.0],
632/// //  [0.0, 3.0, 0.0],
633/// //  [4.0, 0.0, 5.0]]
634/// let row_ptr = vec![0, 2, 3, 5];  // Row pointers
635/// let col_indices = vec![0, 2, 1, 0, 2];  // Column indices
636/// let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];  // Values
637/// let shape = vec![3, 3];
638///
639/// let sparse = SparseCSR::new(row_ptr, col_indices, values, shape).expect("CSR creation should succeed");
640/// ```
#[derive(Debug, Clone)]
pub struct SparseCSR<T: TensorElement> {
    /// Row pointers (length: num_rows + 1); row `r` occupies the half-open
    /// range `row_ptr[r]..row_ptr[r + 1]` of `col_indices`/`values`
    row_ptr: Vec<usize>,
    /// Column index of each stored value
    col_indices: Vec<usize>,
    /// Non-zero values in row-major order
    values: Vec<T>,
    /// Shape of the full dense tensor [num_rows, num_cols]
    shape: Vec<usize>,
    /// Device where the tensor resides
    device: DeviceType,
    /// Number of non-zero elements (equals `values.len()`)
    nnz: usize,
}
656
657impl<T: TensorElement> SparseCSR<T> {
658    /// Create a new CSR sparse tensor
659    ///
660    /// # Arguments
661    /// * `row_ptr` - Row pointers (length: num_rows + 1)
662    /// * `col_indices` - Column indices for non-zero values
663    /// * `values` - Non-zero values
664    /// * `shape` - Shape [num_rows, num_cols]
665    pub fn new(
666        row_ptr: Vec<usize>,
667        col_indices: Vec<usize>,
668        values: Vec<T>,
669        shape: Vec<usize>,
670    ) -> Result<Self> {
671        if shape.len() != 2 {
672            return Err(TorshError::InvalidArgument(
673                "CSR format only supports 2D tensors".to_string(),
674            ));
675        }
676
677        if col_indices.len() != values.len() {
678            return Err(TorshError::InvalidArgument(format!(
679                "Column indices length ({}) must match values length ({})",
680                col_indices.len(),
681                values.len()
682            )));
683        }
684
685        if row_ptr.len() != shape[0] + 1 {
686            return Err(TorshError::InvalidArgument(format!(
687                "Row pointer length ({}) must be num_rows + 1 ({})",
688                row_ptr.len(),
689                shape[0] + 1
690            )));
691        }
692
693        // Validate row pointers are monotonically increasing
694        for i in 1..row_ptr.len() {
695            if row_ptr[i] < row_ptr[i - 1] {
696                return Err(TorshError::InvalidArgument(
697                    "Row pointers must be monotonically increasing".to_string(),
698                ));
699            }
700        }
701
702        // Validate column indices are within bounds
703        for &col_idx in &col_indices {
704            if col_idx >= shape[1] {
705                return Err(TorshError::InvalidArgument(format!(
706                    "Column index {} out of bounds for shape {:?}",
707                    col_idx, shape
708                )));
709            }
710        }
711
712        let nnz = values.len();
713        if row_ptr.last().copied().unwrap_or(0) != nnz {
714            return Err(TorshError::InvalidArgument(
715                "Last row pointer must equal number of non-zero values".to_string(),
716            ));
717        }
718
719        Ok(Self {
720            row_ptr,
721            col_indices,
722            values,
723            shape,
724            device: DeviceType::Cpu,
725            nnz,
726        })
727    }
728
729    /// Convert from COO format to CSR format
730    pub fn from_coo(coo: &SparseTensor<T>) -> Result<Self>
731    where
732        T: Copy,
733    {
734        if coo.shape().len() != 2 {
735            return Err(TorshError::InvalidArgument(
736                "CSR format only supports 2D tensors".to_string(),
737            ));
738        }
739
740        let num_rows = coo.shape()[0];
741        let num_cols = coo.shape()[1];
742
743        // Sort COO entries by row, then column
744        let mut entries: Vec<(usize, usize, T)> = coo
745            .indices()
746            .iter()
747            .zip(coo.values())
748            .map(|(coords, &val)| (coords[0], coords[1], val))
749            .collect();
750
751        entries.sort_by(|a, b| {
752            if a.0 == b.0 {
753                a.1.cmp(&b.1)
754            } else {
755                a.0.cmp(&b.0)
756            }
757        });
758
759        // Build CSR structure
760        let mut row_ptr = vec![0; num_rows + 1];
761        let mut col_indices = Vec::with_capacity(entries.len());
762        let mut values = Vec::with_capacity(entries.len());
763
764        for (row, col, val) in entries {
765            col_indices.push(col);
766            values.push(val);
767            row_ptr[row + 1] += 1;
768        }
769
770        // Convert counts to cumulative sum
771        for i in 1..=num_rows {
772            row_ptr[i] += row_ptr[i - 1];
773        }
774
775        Self::new(row_ptr, col_indices, values, vec![num_rows, num_cols])
776    }
777
778    /// Convert to COO format
779    pub fn to_coo(&self) -> Result<SparseTensor<T>>
780    where
781        T: Copy,
782    {
783        let mut indices = Vec::new();
784        let mut values = Vec::new();
785
786        for row in 0..self.shape[0] {
787            let start = self.row_ptr[row];
788            let end = self.row_ptr[row + 1];
789
790            for idx in start..end {
791                indices.push(vec![row, self.col_indices[idx]]);
792                values.push(self.values[idx]);
793            }
794        }
795
796        SparseTensor::from_coo(indices, values, self.shape.clone())
797    }
798
799    /// Convert to dense tensor
800    pub fn to_dense(&self) -> Result<Tensor<T>>
801    where
802        T: Copy + num_traits::Zero,
803    {
804        let total_elements = self.shape[0] * self.shape[1];
805        let mut data = vec![<T as num_traits::Zero>::zero(); total_elements];
806
807        for row in 0..self.shape[0] {
808            let start = self.row_ptr[row];
809            let end = self.row_ptr[row + 1];
810
811            for idx in start..end {
812                let col = self.col_indices[idx];
813                let flat_idx = row * self.shape[1] + col;
814                data[flat_idx] = self.values[idx];
815            }
816        }
817
818        Tensor::from_data(data, self.shape.clone(), self.device)
819    }
820
821    /// Matrix-vector multiplication (optimized for CSR)
822    pub fn matvec(&self, vec: &[T]) -> Result<Vec<T>>
823    where
824        T: Copy + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + num_traits::Zero,
825    {
826        if vec.len() != self.shape[1] {
827            return Err(TorshError::InvalidArgument(format!(
828                "Vector length ({}) must match number of columns ({})",
829                vec.len(),
830                self.shape[1]
831            )));
832        }
833
834        let mut result = vec![<T as num_traits::Zero>::zero(); self.shape[0]];
835
836        // Parallel row-wise computation using SciRS2
837        result
838            .par_iter_mut()
839            .enumerate()
840            .for_each(|(row, result_val)| {
841                let start = self.row_ptr[row];
842                let end = self.row_ptr[row + 1];
843                let mut sum = <T as num_traits::Zero>::zero();
844
845                for idx in start..end {
846                    let col = self.col_indices[idx];
847                    sum = sum + self.values[idx] * vec[col];
848                }
849
850                *result_val = sum;
851            });
852
853        Ok(result)
854    }
855
856    /// Get a specific row as a sparse vector
857    pub fn get_row(&self, row: usize) -> Result<(Vec<usize>, Vec<T>)>
858    where
859        T: Copy,
860    {
861        if row >= self.shape[0] {
862            return Err(TorshError::InvalidArgument(format!(
863                "Row {} out of bounds for shape {:?}",
864                row, self.shape
865            )));
866        }
867
868        let start = self.row_ptr[row];
869        let end = self.row_ptr[row + 1];
870
871        let col_indices = self.col_indices[start..end].to_vec();
872        let values = self.values[start..end].to_vec();
873
874        Ok((col_indices, values))
875    }
876
    /// Number of stored (non-zero) elements
    pub fn nnz(&self) -> usize {
        self.nnz
    }

    /// Shape of the full dense tensor as `[num_rows, num_cols]`
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Device where the tensor resides
    pub fn device(&self) -> DeviceType {
        self.device
    }

    /// Row pointer array (length `num_rows + 1`)
    pub fn row_ptr(&self) -> &[usize] {
        &self.row_ptr
    }

    /// Column index of each stored value
    pub fn col_indices(&self) -> &[usize] {
        &self.col_indices
    }

    /// Stored values in row-major order
    pub fn values(&self) -> &[T] {
        &self.values
    }
901}
902
/// Sparse tensor in CSC (Compressed Sparse Column) format
///
/// CSC format is optimized for column-wise operations.
/// It stores:
/// - `col_ptr`: Array of size (num_cols + 1) indicating where each column starts
/// - `row_indices`: Row indices for each non-zero value
/// - `values`: Non-zero values in column-major order
///
/// # Example
/// ```rust
/// use torsh_tensor::sparse::SparseCSC;
///
/// // Create a 3x3 sparse matrix in CSC format
/// // [[1.0, 0.0, 2.0],
/// //  [0.0, 3.0, 0.0],
/// //  [4.0, 0.0, 5.0]]
/// let col_ptr = vec![0, 2, 3, 5];  // Column pointers
/// let row_indices = vec![0, 2, 1, 0, 2];  // Row indices
/// let values = vec![1.0, 4.0, 3.0, 2.0, 5.0];  // Values
/// let shape = vec![3, 3];
///
/// let sparse = SparseCSC::new(col_ptr, row_indices, values, shape).expect("CSC creation should succeed");
/// ```
#[derive(Debug, Clone)]
pub struct SparseCSC<T: TensorElement> {
    /// Column pointers (size: num_cols + 1); column `c` occupies
    /// `col_ptr[c]..col_ptr[c + 1]` of `row_indices`/`values`
    col_ptr: Vec<usize>,
    /// Row indices for each non-zero value (parallel to `values`)
    row_indices: Vec<usize>,
    /// Non-zero values in column-major order
    values: Vec<T>,
    /// Shape of the full dense tensor [num_rows, num_cols]
    shape: Vec<usize>,
    /// Device where the tensor resides
    device: DeviceType,
    /// Number of non-zero elements (cached copy of `values.len()`)
    nnz: usize,
}
941
942impl<T: TensorElement> SparseCSC<T> {
943    /// Create a new CSC sparse tensor
944    ///
945    /// # Arguments
946    /// * `col_ptr` - Column pointers (length: num_cols + 1)
947    /// * `row_indices` - Row indices for non-zero values
948    /// * `values` - Non-zero values
949    /// * `shape` - Shape [num_rows, num_cols]
950    pub fn new(
951        col_ptr: Vec<usize>,
952        row_indices: Vec<usize>,
953        values: Vec<T>,
954        shape: Vec<usize>,
955    ) -> Result<Self> {
956        if shape.len() != 2 {
957            return Err(TorshError::InvalidArgument(
958                "CSC format only supports 2D tensors".to_string(),
959            ));
960        }
961
962        if row_indices.len() != values.len() {
963            return Err(TorshError::InvalidArgument(format!(
964                "Row indices length ({}) must match values length ({})",
965                row_indices.len(),
966                values.len()
967            )));
968        }
969
970        if col_ptr.len() != shape[1] + 1 {
971            return Err(TorshError::InvalidArgument(format!(
972                "Column pointer length ({}) must be num_cols + 1 ({})",
973                col_ptr.len(),
974                shape[1] + 1
975            )));
976        }
977
978        // Validate column pointers are monotonically increasing
979        for i in 1..col_ptr.len() {
980            if col_ptr[i] < col_ptr[i - 1] {
981                return Err(TorshError::InvalidArgument(
982                    "Column pointers must be monotonically increasing".to_string(),
983                ));
984            }
985        }
986
987        // Validate row indices are within bounds
988        for &row_idx in &row_indices {
989            if row_idx >= shape[0] {
990                return Err(TorshError::InvalidArgument(format!(
991                    "Row index {} out of bounds for shape {:?}",
992                    row_idx, shape
993                )));
994            }
995        }
996
997        let nnz = values.len();
998        if col_ptr.last().copied().unwrap_or(0) != nnz {
999            return Err(TorshError::InvalidArgument(
1000                "Last column pointer must equal number of non-zero values".to_string(),
1001            ));
1002        }
1003
1004        Ok(Self {
1005            col_ptr,
1006            row_indices,
1007            values,
1008            shape,
1009            device: DeviceType::Cpu,
1010            nnz,
1011        })
1012    }
1013
1014    /// Convert from COO format to CSC format
1015    pub fn from_coo(coo: &SparseTensor<T>) -> Result<Self>
1016    where
1017        T: Copy,
1018    {
1019        if coo.shape().len() != 2 {
1020            return Err(TorshError::InvalidArgument(
1021                "CSC format only supports 2D tensors".to_string(),
1022            ));
1023        }
1024
1025        let num_rows = coo.shape()[0];
1026        let num_cols = coo.shape()[1];
1027
1028        // Sort COO entries by column, then row
1029        let mut entries: Vec<(usize, usize, T)> = coo
1030            .indices()
1031            .iter()
1032            .zip(coo.values())
1033            .map(|(coords, &val)| (coords[0], coords[1], val))
1034            .collect();
1035
1036        entries.sort_by(|a, b| {
1037            if a.1 == b.1 {
1038                a.0.cmp(&b.0)
1039            } else {
1040                a.1.cmp(&b.1)
1041            }
1042        });
1043
1044        // Build CSC structure
1045        let mut col_ptr = vec![0; num_cols + 1];
1046        let mut row_indices = Vec::with_capacity(entries.len());
1047        let mut values = Vec::with_capacity(entries.len());
1048
1049        for (row, col, val) in entries {
1050            row_indices.push(row);
1051            values.push(val);
1052            col_ptr[col + 1] += 1;
1053        }
1054
1055        // Convert counts to cumulative sum
1056        for i in 1..=num_cols {
1057            col_ptr[i] += col_ptr[i - 1];
1058        }
1059
1060        Self::new(col_ptr, row_indices, values, vec![num_rows, num_cols])
1061    }
1062
1063    /// Convert to COO format
1064    pub fn to_coo(&self) -> Result<SparseTensor<T>>
1065    where
1066        T: Copy,
1067    {
1068        let mut indices = Vec::new();
1069        let mut values = Vec::new();
1070
1071        for col in 0..self.shape[1] {
1072            let start = self.col_ptr[col];
1073            let end = self.col_ptr[col + 1];
1074
1075            for idx in start..end {
1076                indices.push(vec![self.row_indices[idx], col]);
1077                values.push(self.values[idx]);
1078            }
1079        }
1080
1081        SparseTensor::from_coo(indices, values, self.shape.clone())
1082    }
1083
1084    /// Convert to dense tensor
1085    pub fn to_dense(&self) -> Result<Tensor<T>>
1086    where
1087        T: Copy + num_traits::Zero,
1088    {
1089        let total_elements = self.shape[0] * self.shape[1];
1090        let mut data = vec![<T as num_traits::Zero>::zero(); total_elements];
1091
1092        for col in 0..self.shape[1] {
1093            let start = self.col_ptr[col];
1094            let end = self.col_ptr[col + 1];
1095
1096            for idx in start..end {
1097                let row = self.row_indices[idx];
1098                let flat_idx = row * self.shape[1] + col;
1099                data[flat_idx] = self.values[idx];
1100            }
1101        }
1102
1103        Tensor::from_data(data, self.shape.clone(), self.device)
1104    }
1105
1106    /// Matrix-vector multiplication with transposed matrix (A^T * v) - optimized for CSC
1107    pub fn transpose_matvec(&self, vec: &[T]) -> Result<Vec<T>>
1108    where
1109        T: Copy + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + num_traits::Zero,
1110    {
1111        if vec.len() != self.shape[0] {
1112            return Err(TorshError::InvalidArgument(format!(
1113                "Vector length ({}) must match number of rows ({})",
1114                vec.len(),
1115                self.shape[0]
1116            )));
1117        }
1118
1119        let mut result = vec![<T as num_traits::Zero>::zero(); self.shape[1]];
1120
1121        // Parallel column-wise computation using SciRS2
1122        result
1123            .par_iter_mut()
1124            .enumerate()
1125            .for_each(|(col, result_val)| {
1126                let start = self.col_ptr[col];
1127                let end = self.col_ptr[col + 1];
1128                let mut sum = <T as num_traits::Zero>::zero();
1129
1130                for idx in start..end {
1131                    let row = self.row_indices[idx];
1132                    sum = sum + self.values[idx] * vec[row];
1133                }
1134
1135                *result_val = sum;
1136            });
1137
1138        Ok(result)
1139    }
1140
1141    /// Get a specific column as a sparse vector
1142    pub fn get_col(&self, col: usize) -> Result<(Vec<usize>, Vec<T>)>
1143    where
1144        T: Copy,
1145    {
1146        if col >= self.shape[1] {
1147            return Err(TorshError::InvalidArgument(format!(
1148                "Column {} out of bounds for shape {:?}",
1149                col, self.shape
1150            )));
1151        }
1152
1153        let start = self.col_ptr[col];
1154        let end = self.col_ptr[col + 1];
1155
1156        let row_indices = self.row_indices[start..end].to_vec();
1157        let values = self.values[start..end].to_vec();
1158
1159        Ok((row_indices, values))
1160    }
1161
1162    /// Getters
1163    pub fn nnz(&self) -> usize {
1164        self.nnz
1165    }
1166
1167    pub fn shape(&self) -> &[usize] {
1168        &self.shape
1169    }
1170
1171    pub fn device(&self) -> DeviceType {
1172        self.device
1173    }
1174
1175    pub fn col_ptr(&self) -> &[usize] {
1176        &self.col_ptr
1177    }
1178
1179    pub fn row_indices(&self) -> &[usize] {
1180        &self.row_indices
1181    }
1182
1183    pub fn values(&self) -> &[T] {
1184        &self.values
1185    }
1186}
1187
// Unit tests: COO construction and arithmetic, CSR/CSC construction and
// validation, format round-trips, and error paths.
#[cfg(test)]
mod tests {
    use super::*;
    use torsh_core::device::DeviceType;

    // Tests for COO format
    #[test]
    fn test_sparse_tensor_creation() {
        let indices = vec![vec![0, 0], vec![1, 2], vec![2, 1]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![3, 3];

        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        assert_eq!(sparse.nnz(), 3);
        assert_eq!(sparse.shape(), &[3, 3]);
        assert!(sparse.sparsity() > 0.6); // 6 out of 9 elements are zero
    }

    #[test]
    fn test_sparse_to_dense_conversion() {
        let indices = vec![vec![0, 0], vec![1, 1], vec![2, 2]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![3, 3];

        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        let dense = sparse
            .to_dense()
            .expect("sparse to dense conversion should succeed");

        // Diagonal matrix in row-major order.
        let expected_data = vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0];

        assert_eq!(
            dense.data().expect("data access should succeed"),
            expected_data
        );
    }

    #[test]
    fn test_sparse_addition() {
        let indices1 = vec![vec![0, 0], vec![1, 1]];
        let values1 = vec![1.0, 2.0];
        let shape = vec![3, 3];
        let sparse1 = SparseTensor::from_coo(indices1, values1, shape.clone())
            .expect("COO sparse tensor creation should succeed");

        let indices2 = vec![vec![0, 0], vec![2, 2]];
        let values2 = vec![3.0, 4.0];
        let sparse2 = SparseTensor::from_coo(indices2, values2, shape)
            .expect("COO sparse tensor creation should succeed");

        let result = sparse1
            .add(&sparse2)
            .expect("sparse addition should succeed");

        // Should have (0,0)=4.0, (1,1)=2.0, (2,2)=4.0
        assert_eq!(result.nnz(), 3);

        let dense_result = result
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![4.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 4.0];
        assert_eq!(
            dense_result.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_sparse_multiplication() {
        // Element-wise product: only overlapping coordinates survive.
        let indices1 = vec![vec![0, 0], vec![1, 1], vec![2, 2]];
        let values1 = vec![2.0, 3.0, 4.0];
        let shape = vec![3, 3];
        let sparse1 = SparseTensor::from_coo(indices1, values1, shape.clone())
            .expect("COO sparse tensor creation should succeed");

        let indices2 = vec![vec![0, 0], vec![1, 1]];
        let values2 = vec![5.0, 6.0];
        let sparse2 = SparseTensor::from_coo(indices2, values2, shape)
            .expect("COO sparse tensor creation should succeed");

        let result = sparse1
            .mul(&sparse2)
            .expect("sparse element-wise multiplication should succeed");

        // Should have (0,0)=10.0, (1,1)=18.0
        assert_eq!(result.nnz(), 2);

        let dense_result = result
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![10.0, 0.0, 0.0, 0.0, 18.0, 0.0, 0.0, 0.0, 0.0];
        assert_eq!(
            dense_result.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_sparse_matmul() {
        // Create a 2x2 sparse matrix [[1, 0], [0, 2]]
        let indices1 = vec![vec![0, 0], vec![1, 1]];
        let values1 = vec![1.0, 2.0];
        let shape1 = vec![2, 2];
        let sparse1 = SparseTensor::from_coo(indices1, values1, shape1)
            .expect("COO sparse tensor creation should succeed");

        // Create a 2x2 sparse matrix [[3, 0], [0, 4]]
        let indices2 = vec![vec![0, 0], vec![1, 1]];
        let values2 = vec![3.0, 4.0];
        let shape2 = vec![2, 2];
        let sparse2 = SparseTensor::from_coo(indices2, values2, shape2)
            .expect("COO sparse tensor creation should succeed");

        let result = sparse1
            .matmul(&sparse2)
            .expect("sparse matrix multiplication should succeed");

        // Result should be [[3, 0], [0, 8]]
        assert_eq!(result.nnz(), 2);

        let dense_result = result
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![3.0, 0.0, 0.0, 8.0];
        assert_eq!(
            dense_result.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_sparse_transpose() {
        let indices = vec![vec![0, 1], vec![1, 0], vec![2, 1]];
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![3, 2];
        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");

        let transposed = sparse.transpose().expect("sparse transpose should succeed");
        assert_eq!(transposed.shape(), &[2, 3]);

        let dense_transposed = transposed
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![0.0, 2.0, 0.0, 1.0, 0.0, 3.0];
        assert_eq!(
            dense_transposed.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_sparse_identity() {
        let eye = SparseTensor::<f32>::eye(3).expect("sparse identity creation should succeed");
        assert_eq!(eye.nnz(), 3);
        assert_eq!(eye.shape(), &[3, 3]);

        let dense_eye = eye
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        assert_eq!(
            dense_eye.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_memory_efficiency() {
        let indices = vec![vec![0, 0]]; // Only one non-zero element
        let values = vec![1.0];
        let shape = vec![1000, 1000]; // Large tensor
        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");

        assert!(sparse.sparsity() > 0.999); // Very sparse
        assert!(sparse.memory_efficiency() > 0.9); // Much more memory efficient
    }

    #[test]
    fn test_from_dense() {
        let data = vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0];
        let dense = Tensor::from_data(data, vec![2, 3], DeviceType::Cpu)
            .expect("tensor creation should succeed");

        let sparse =
            SparseTensor::from_dense(&dense, 1e-6).expect("from_dense conversion should succeed");
        assert_eq!(sparse.nnz(), 2);

        // Round-trip must reproduce the original dense data exactly.
        let back_to_dense = sparse
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        assert_eq!(
            dense.data().expect("data access should succeed"),
            back_to_dense.data().expect("data access should succeed")
        );
    }

    #[test]
    fn test_coalesce() {
        // Create sparse tensor with duplicate indices
        let indices = vec![vec![0, 0], vec![1, 1], vec![0, 0]]; // (0,0) appears twice
        let values = vec![1.0, 2.0, 3.0];
        let shape = vec![2, 2];

        let mut sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        assert_eq!(sparse.nnz(), 3);

        sparse.coalesce().expect("coalesce should succeed");
        assert_eq!(sparse.nnz(), 2); // Should have combined duplicates

        let dense = sparse
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![4.0, 0.0, 0.0, 2.0]; // (0,0) = 1+3 = 4
        assert_eq!(dense.data().expect("data access should succeed"), expected);
    }

    #[test]
    fn test_scalar_multiplication() {
        let indices = vec![vec![0, 0], vec![1, 1]];
        let values = vec![2.0, 3.0];
        let shape = vec![2, 2];
        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");

        let result = sparse
            .mul_scalar(2.0)
            .expect("scalar multiplication should succeed");
        assert_eq!(result.nnz(), 2);

        let dense_result = result
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![4.0, 0.0, 0.0, 6.0];
        assert_eq!(
            dense_result.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_map_function() {
        let indices = vec![vec![0, 0], vec![1, 1]];
        let values = vec![2.0, 3.0];
        let shape = vec![2, 2];
        let sparse = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");

        let result = sparse.map(|x| x * x).expect("map operation should succeed"); // Square all values
        assert_eq!(result.nnz(), 2);

        let dense_result = result
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let expected = vec![4.0, 0.0, 0.0, 9.0];
        assert_eq!(
            dense_result.data().expect("data access should succeed"),
            expected
        );
    }

    #[test]
    fn test_error_cases() {
        // Mismatched indices and values length
        let indices = vec![vec![0, 0]];
        let values = vec![1.0, 2.0]; // Different length
        let shape = vec![2, 2];
        assert!(SparseTensor::from_coo(indices, values, shape).is_err());

        // Out of bounds indices
        let indices = vec![vec![2, 0]]; // Row 2 is out of bounds for 2x2 matrix
        let values = vec![1.0];
        let shape = vec![2, 2];
        assert!(SparseTensor::from_coo(indices, values, shape).is_err());
    }

    // Tests for CSR format
    #[test]
    fn test_csr_creation() {
        // Create CSR matrix:
        // [[1.0, 0.0, 2.0],
        //  [0.0, 3.0, 0.0],
        //  [4.0, 0.0, 5.0]]
        let row_ptr = vec![0, 2, 3, 5];
        let col_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSR::new(row_ptr, col_indices, values, shape)
            .expect("CSR creation should succeed");
        assert_eq!(sparse.nnz(), 5);
        assert_eq!(sparse.shape(), &[3, 3]);
    }

    #[test]
    fn test_csr_to_dense() {
        let row_ptr = vec![0, 2, 3, 5];
        let col_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSR::new(row_ptr, col_indices, values, shape)
            .expect("CSR creation should succeed");
        let dense = sparse
            .to_dense()
            .expect("sparse to dense conversion should succeed");

        let expected = vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0];
        assert_eq!(dense.data().expect("data access should succeed"), expected);
    }

    #[test]
    fn test_csr_from_coo() {
        let indices = vec![vec![0, 0], vec![0, 2], vec![1, 1], vec![2, 0], vec![2, 2]];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let coo = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        let csr = SparseCSR::from_coo(&coo).expect("COO sparse tensor creation should succeed");

        assert_eq!(csr.nnz(), 5);
        assert_eq!(csr.shape(), &[3, 3]);

        // Verify CSR structure
        assert_eq!(csr.row_ptr(), &[0, 2, 3, 5]);
        assert_eq!(csr.col_indices(), &[0, 2, 1, 0, 2]);
    }

    #[test]
    fn test_csr_matvec() {
        // Matrix: [[1.0, 0.0], [0.0, 2.0]]
        let row_ptr = vec![0, 1, 2];
        let col_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];

        let sparse = SparseCSR::new(row_ptr, col_indices, values, shape)
            .expect("CSR creation should succeed");
        let vec = vec![3.0, 4.0];
        let result = sparse
            .matvec(&vec)
            .expect("matrix-vector multiplication should succeed");

        // Expected: [1*3, 2*4] = [3.0, 8.0]
        assert_eq!(result, vec![3.0, 8.0]);
    }

    #[test]
    fn test_csr_get_row() {
        let row_ptr = vec![0, 2, 3, 5];
        let col_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSR::new(row_ptr, col_indices, values, shape)
            .expect("CSR creation should succeed");

        // Get row 0
        let (cols, vals) = sparse.get_row(0).expect("row access should succeed");
        assert_eq!(cols, vec![0, 2]);
        assert_eq!(vals, vec![1.0, 2.0]);

        // Get row 1
        let (cols, vals) = sparse.get_row(1).expect("row access should succeed");
        assert_eq!(cols, vec![1]);
        assert_eq!(vals, vec![3.0]);
    }

    #[test]
    fn test_csr_to_coo() {
        let row_ptr = vec![0, 2, 3, 5];
        let col_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let csr = SparseCSR::new(row_ptr, col_indices, values, shape)
            .expect("CSR creation should succeed");
        let coo = csr.to_coo().expect("to COO conversion should succeed");

        // Both representations must densify to the same data.
        assert_eq!(coo.nnz(), 5);
        let dense_coo = coo
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let dense_csr = csr
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        assert_eq!(
            dense_coo.data().expect("data access should succeed"),
            dense_csr.data().expect("data access should succeed")
        );
    }

    // Tests for CSC format
    #[test]
    fn test_csc_creation() {
        // Create CSC matrix:
        // [[1.0, 0.0, 2.0],
        //  [0.0, 3.0, 0.0],
        //  [4.0, 0.0, 5.0]]
        let col_ptr = vec![0, 2, 3, 5];
        let row_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSC::new(col_ptr, row_indices, values, shape)
            .expect("CSC creation should succeed");
        assert_eq!(sparse.nnz(), 5);
        assert_eq!(sparse.shape(), &[3, 3]);
    }

    #[test]
    fn test_csc_to_dense() {
        let col_ptr = vec![0, 2, 3, 5];
        let row_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSC::new(col_ptr, row_indices, values, shape)
            .expect("CSC creation should succeed");
        let dense = sparse
            .to_dense()
            .expect("sparse to dense conversion should succeed");

        let expected = vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0];
        assert_eq!(dense.data().expect("data access should succeed"), expected);
    }

    #[test]
    fn test_csc_from_coo() {
        let indices = vec![vec![0, 0], vec![0, 2], vec![1, 1], vec![2, 0], vec![2, 2]];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let coo = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        let csc = SparseCSC::from_coo(&coo).expect("COO sparse tensor creation should succeed");

        assert_eq!(csc.nnz(), 5);
        assert_eq!(csc.shape(), &[3, 3]);

        // Verify CSC structure
        assert_eq!(csc.col_ptr(), &[0, 2, 3, 5]);
        assert_eq!(csc.row_indices(), &[0, 2, 1, 0, 2]);
    }

    #[test]
    fn test_csc_transpose_matvec() {
        // Matrix: [[1.0, 0.0], [0.0, 2.0]]
        // CSC stores column-wise
        let col_ptr = vec![0, 1, 2];
        let row_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];

        let sparse = SparseCSC::new(col_ptr, row_indices, values, shape)
            .expect("CSC creation should succeed");
        let vec = vec![3.0, 4.0];
        let result = sparse
            .transpose_matvec(&vec)
            .expect("transpose matrix-vector multiplication should succeed");

        // Expected: A^T * [3, 4] = [[1, 0], [0, 2]]^T * [3, 4] = [3.0, 8.0]
        assert_eq!(result, vec![3.0, 8.0]);
    }

    #[test]
    fn test_csc_get_col() {
        let col_ptr = vec![0, 2, 3, 5];
        let row_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let shape = vec![3, 3];

        let sparse = SparseCSC::new(col_ptr, row_indices, values, shape)
            .expect("CSC creation should succeed");

        // Get column 0
        let (rows, vals) = sparse.get_col(0).expect("column access should succeed");
        assert_eq!(rows, vec![0, 2]);
        assert_eq!(vals, vec![1.0, 4.0]);

        // Get column 1
        let (rows, vals) = sparse.get_col(1).expect("column access should succeed");
        assert_eq!(rows, vec![1]);
        assert_eq!(vals, vec![3.0]);
    }

    #[test]
    fn test_csc_to_coo() {
        let col_ptr = vec![0, 2, 3, 5];
        let row_indices = vec![0, 2, 1, 0, 2];
        let values = vec![1.0, 4.0, 3.0, 2.0, 5.0];
        let shape = vec![3, 3];

        let csc = SparseCSC::new(col_ptr, row_indices, values, shape)
            .expect("CSC creation should succeed");
        let coo = csc.to_coo().expect("to COO conversion should succeed");

        // Both representations must densify to the same data.
        assert_eq!(coo.nnz(), 5);
        let dense_coo = coo
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let dense_csc = csc
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        assert_eq!(
            dense_coo.data().expect("data access should succeed"),
            dense_csc.data().expect("data access should succeed")
        );
    }

    #[test]
    fn test_format_conversions() {
        // Test round-trip: COO -> CSR -> COO -> CSC -> COO
        let indices = vec![vec![0, 0], vec![0, 2], vec![1, 1], vec![2, 0], vec![2, 2]];
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let shape = vec![3, 3];

        let coo1 = SparseTensor::from_coo(indices, values, shape)
            .expect("COO sparse tensor creation should succeed");
        let csr = SparseCSR::from_coo(&coo1).expect("COO sparse tensor creation should succeed");
        let coo2 = csr.to_coo().expect("to COO conversion should succeed");
        let csc = SparseCSC::from_coo(&coo2).expect("COO sparse tensor creation should succeed");
        let coo3 = csc.to_coo().expect("to COO conversion should succeed");

        // All should represent the same matrix
        let dense1 = coo1
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let dense2 = coo2
            .to_dense()
            .expect("sparse to dense conversion should succeed");
        let dense3 = coo3
            .to_dense()
            .expect("sparse to dense conversion should succeed");

        assert_eq!(
            dense1.data().expect("data access should succeed"),
            dense2.data().expect("data access should succeed")
        );
        assert_eq!(
            dense2.data().expect("data access should succeed"),
            dense3.data().expect("data access should succeed")
        );
    }

    #[test]
    fn test_csr_error_cases() {
        // Wrong shape length
        let row_ptr = vec![0, 1];
        let col_indices = vec![0];
        let values = vec![1.0];
        let shape = vec![1]; // 1D not supported
        assert!(SparseCSR::new(row_ptr, col_indices, values, shape).is_err());

        // Mismatched lengths
        let row_ptr = vec![0, 2];
        let col_indices = vec![0];
        let values = vec![1.0, 2.0]; // Different from col_indices
        let shape = vec![1, 2];
        assert!(SparseCSR::new(row_ptr, col_indices, values, shape).is_err());

        // Non-monotonic row pointers
        let row_ptr = vec![0, 2, 1]; // Decreases
        let col_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];
        assert!(SparseCSR::new(row_ptr, col_indices, values, shape).is_err());

        // Column index out of bounds
        let row_ptr = vec![0, 1];
        let col_indices = vec![5]; // Out of bounds
        let values = vec![1.0];
        let shape = vec![1, 2];
        assert!(SparseCSR::new(row_ptr, col_indices, values, shape).is_err());
    }

    #[test]
    fn test_csc_error_cases() {
        // Wrong shape length
        let col_ptr = vec![0, 1];
        let row_indices = vec![0];
        let values = vec![1.0];
        let shape = vec![1]; // 1D not supported
        assert!(SparseCSC::new(col_ptr, row_indices, values, shape).is_err());

        // Mismatched lengths
        let col_ptr = vec![0, 2];
        let row_indices = vec![0];
        let values = vec![1.0, 2.0]; // Different from row_indices
        let shape = vec![2, 1];
        assert!(SparseCSC::new(col_ptr, row_indices, values, shape).is_err());

        // Non-monotonic column pointers
        let col_ptr = vec![0, 2, 1]; // Decreases
        let row_indices = vec![0, 1];
        let values = vec![1.0, 2.0];
        let shape = vec![2, 2];
        assert!(SparseCSC::new(col_ptr, row_indices, values, shape).is_err());

        // Row index out of bounds
        let col_ptr = vec![0, 1];
        let row_indices = vec![5]; // Out of bounds
        let values = vec![1.0];
        let shape = vec![2, 1];
        assert!(SparseCSC::new(col_ptr, row_indices, values, shape).is_err());
    }
}