// Source file: trustformers_core/sparse_ops.rs
1//! Advanced sparse tensor operations and structured sparsity
2//!
3//! This module provides high-performance sparse operations optimized for transformer models,
4//! including sparse matrix multiplication, structured sparsity patterns, and pruning utilities.
5//!
6//! # Features
7//!
8//! - **Sparse Matrix Multiplication**: SpMM, SpMSpM with various formats
9//! - **Structured Sparsity**: N:M sparsity, block sparsity, channel pruning
10//! - **Sparse Attention**: Memory-efficient attention for long sequences
11//! - **Pruning Utilities**: Magnitude pruning, gradient-based pruning
12//! - **Format Conversion**: Efficient COO ↔ CSR ↔ CSC ↔ BSR conversions
13//!
14//! # Examples
15//!
16//! ```rust
17//! use trustformers_core::sparse_ops::{sparse_matmul, StructuredSparsityPattern, NMSparsity};
18//! use trustformers_core::sparse_tensor::SparseTensor;
19//! use trustformers_core::tensor::Tensor;
20//!
21//! // Create structured N:M sparsity
22//! let pattern = NMSparsity::new(2, 4); // 2:4 sparsity (50%)
23//! let dense = Tensor::randn(&[128, 128])?;
24//! let sparse = pattern.apply(&dense)?;
25//!
26//! // Sparse-dense matrix multiplication
27//! let result = sparse_matmul(&sparse, &dense)?;
28//! # Ok::<(), Box<dyn std::error::Error>>(())
29//! ```
30
31use crate::errors::{Result, TrustformersError};
32use crate::sparse_tensor::{SparseFormat, SparseIndices, SparseTensor};
33use crate::tensor::Tensor;
34use serde::{Deserialize, Serialize};
35use std::collections::HashSet;
36
/// Structured sparsity pattern types
///
/// Describes how zeros are distributed in a pruned tensor. Serializable so a
/// chosen pattern can be stored alongside model checkpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StructuredSparsityPattern {
    /// N:M sparsity - N non-zero elements kept in every window of M elements
    NM { n: usize, m: usize },

    /// Block sparsity - blocks of size (block_height, block_width) are either all zero or all non-zero
    Block {
        block_height: usize,
        block_width: usize,
    },

    /// Channel pruning - entire channels (columns or rows) along `dimension`
    /// are pruned, keeping the `keep_ratio` fraction
    Channel { dimension: usize, keep_ratio: f32 },

    /// Head pruning - prune entire attention heads, keeping `keep_ratio` of `num_heads`
    Head { num_heads: usize, keep_ratio: f32 },

    /// Random sparsity - random elements are pruned with given sparsity fraction
    Random { sparsity: f32 },

    /// Magnitude-based - keep top-k elements by absolute value (`keep_ratio` fraction)
    Magnitude { keep_ratio: f32 },
}
61
/// N:M structured sparsity implementation
///
/// Keeps the `n` largest-magnitude elements in every contiguous window of `m`
/// columns of each row (applied by [`NMSparsity::apply`]).
pub struct NMSparsity {
    n: usize, // non-zero elements kept per window
    m: usize, // window size in columns
}
67
68impl NMSparsity {
69    /// Create a new N:M sparsity pattern
70    ///
71    /// # Arguments
72    /// * `n` - Number of non-zero elements to keep
73    /// * `m` - Window size (n elements out of every m)
74    ///
75    /// Common patterns:
76    /// - 1:2 = 50% sparsity
77    /// - 2:4 = 50% sparsity (better for hardware)
78    /// - 1:4 = 75% sparsity
79    pub fn new(n: usize, m: usize) -> Self {
80        assert!(n <= m, "N must be <= M in N:M sparsity");
81        Self { n, m }
82    }
83
84    /// Apply N:M sparsity to a dense tensor
85    pub fn apply(&self, tensor: &Tensor) -> Result<SparseTensor> {
86        let data = tensor.to_vec_f32()?;
87        let shape = tensor.shape().to_vec();
88
89        if shape.len() != 2 {
90            return Err(TrustformersError::shape_error(
91                "N:M sparsity currently supports only 2D tensors".to_string(),
92            ));
93        }
94
95        let rows = shape[0];
96        let cols = shape[1];
97
98        // Check that columns are divisible by M
99        if cols % self.m != 0 {
100            return Err(TrustformersError::shape_error(format!(
101                "Number of columns {} must be divisible by M={}",
102                cols, self.m
103            )));
104        }
105
106        let mut row_indices = Vec::new();
107        let mut col_indices = Vec::new();
108        let mut values = Vec::new();
109
110        // Process each row
111        for row in 0..rows {
112            let row_start = row * cols;
113
114            // Process in windows of M elements
115            for window_start in (0..cols).step_by(self.m) {
116                let window_end = (window_start + self.m).min(cols);
117
118                // Collect values in this window with their original indices
119                let mut window_vals: Vec<(usize, f32)> = (window_start..window_end)
120                    .map(|col| {
121                        let idx = row_start + col;
122                        (col, data[idx])
123                    })
124                    .collect();
125
126                // Sort by absolute value (descending)
127                window_vals.sort_by(|a, b| {
128                    b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal)
129                });
130
131                // Keep top N values
132                for (col, val) in window_vals.iter().take(self.n) {
133                    row_indices.push(row);
134                    col_indices.push(*col);
135                    values.push(*val);
136                }
137            }
138        }
139
140        SparseTensor::new_coo(shape, row_indices, col_indices, values)
141    }
142
143    /// Get theoretical sparsity ratio
144    pub fn sparsity_ratio(&self) -> f32 {
145        1.0 - (self.n as f32 / self.m as f32)
146    }
147}
148
/// Block sparsity implementation
///
/// Partitions a 2D tensor into tiles of `block_height` x `block_width`,
/// scores each tile by its L2 norm, and keeps only the highest-scoring
/// `keep_ratio` fraction of tiles (see `apply`).
pub struct BlockSparsity {
    block_height: usize, // tile height in elements
    block_width: usize,  // tile width in elements
    keep_ratio: f32,     // fraction of tiles to keep
}
155
156impl BlockSparsity {
157    /// Create a new block sparsity pattern
158    pub fn new(block_height: usize, block_width: usize, keep_ratio: f32) -> Self {
159        Self {
160            block_height,
161            block_width,
162            keep_ratio,
163        }
164    }
165
166    /// Apply block sparsity to a dense tensor
167    pub fn apply(&self, tensor: &Tensor) -> Result<SparseTensor> {
168        let data = tensor.to_vec_f32()?;
169        let shape = tensor.shape().to_vec();
170
171        if shape.len() != 2 {
172            return Err(TrustformersError::shape_error(
173                "Block sparsity currently supports only 2D tensors".to_string(),
174            ));
175        }
176
177        let rows = shape[0];
178        let cols = shape[1];
179
180        let num_block_rows = rows.div_ceil(self.block_height);
181        let num_block_cols = cols.div_ceil(self.block_width);
182
183        // Compute importance score for each block
184        let mut block_scores = Vec::new();
185        for br in 0..num_block_rows {
186            for bc in 0..num_block_cols {
187                let row_start = br * self.block_height;
188                let row_end = (row_start + self.block_height).min(rows);
189                let col_start = bc * self.block_width;
190                let col_end = (col_start + self.block_width).min(cols);
191
192                // Compute L2 norm of block
193                let mut block_norm = 0.0f32;
194                for r in row_start..row_end {
195                    for c in col_start..col_end {
196                        let val = data[r * cols + c];
197                        block_norm += val * val;
198                    }
199                }
200                block_norm = block_norm.sqrt();
201
202                block_scores.push(((br, bc), block_norm));
203            }
204        }
205
206        // Sort blocks by importance
207        block_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
208
209        // Keep top blocks
210        let num_blocks_to_keep = ((block_scores.len() as f32) * self.keep_ratio) as usize;
211        let blocks_to_keep: HashSet<(usize, usize)> = block_scores
212            .iter()
213            .take(num_blocks_to_keep)
214            .map(|&((br, bc), _)| (br, bc))
215            .collect();
216
217        // Build sparse tensor from kept blocks
218        let mut row_ptr = vec![0];
219        let mut col_indices = Vec::new();
220        let mut values = Vec::new();
221
222        for br in 0..num_block_rows {
223            let row_start = br * self.block_height;
224            let row_end = (row_start + self.block_height).min(rows);
225
226            for r in row_start..row_end {
227                let mut row_nnz = 0;
228
229                for bc in 0..num_block_cols {
230                    if !blocks_to_keep.contains(&(br, bc)) {
231                        continue;
232                    }
233
234                    let col_start = bc * self.block_width;
235                    let col_end = (col_start + self.block_width).min(cols);
236
237                    for c in col_start..col_end {
238                        let val = data[r * cols + c];
239                        if val != 0.0 {
240                            col_indices.push(c);
241                            values.push(val);
242                            row_nnz += 1;
243                        }
244                    }
245                }
246
247                // row_ptr is never empty - initialized with 0 at start
248                row_ptr.push(row_ptr.last().copied().unwrap_or(0) + row_nnz);
249            }
250        }
251
252        SparseTensor::new_csr(shape, row_ptr, col_indices, values)
253    }
254}
255
256/// Sparse matrix - dense matrix multiplication
257pub fn sparse_matmul(sparse: &SparseTensor, dense: &Tensor) -> Result<Tensor> {
258    let dense_data = dense.to_vec_f32()?;
259    let dense_shape = dense.shape();
260
261    if sparse.shape.len() != 2 || dense_shape.len() != 2 {
262        return Err(TrustformersError::shape_error(
263            "Sparse matmul requires 2D matrices".to_string(),
264        ));
265    }
266
267    if sparse.shape[1] != dense_shape[0] {
268        return Err(TrustformersError::shape_error(format!(
269            "Incompatible shapes for matmul: {:?} x {:?}",
270            sparse.shape, dense_shape
271        )));
272    }
273
274    let m = sparse.shape[0];
275    let _k = sparse.shape[1];
276    let n = dense_shape[1];
277
278    let mut result = vec![0.0f32; m * n];
279
280    match sparse.format {
281        SparseFormat::CSR => {
282            if let SparseIndices::CSR {
283                row_ptr,
284                col_indices,
285            } = &sparse.indices
286            {
287                // CSR format is optimal for SpMM
288                for row in 0..m {
289                    let row_start = row_ptr[row];
290                    let row_end = row_ptr[row + 1];
291
292                    #[allow(clippy::needless_range_loop)]
293                    for j in row_start..row_end {
294                        let col = col_indices[j];
295                        let sparse_val = sparse.values[j];
296
297                        // Compute dot product contribution
298                        for out_col in 0..n {
299                            result[row * n + out_col] += sparse_val * dense_data[col * n + out_col];
300                        }
301                    }
302                }
303            } else {
304                return Err(TrustformersError::tensor_op_error(
305                    "Invalid indices format",
306                    "sparse matmul",
307                ));
308            }
309        },
310        SparseFormat::COO => {
311            if let SparseIndices::COO {
312                row_indices,
313                col_indices,
314            } = &sparse.indices
315            {
316                for ((&row, &col), &val) in
317                    row_indices.iter().zip(col_indices.iter()).zip(sparse.values.iter())
318                {
319                    for out_col in 0..n {
320                        result[row * n + out_col] += val * dense_data[col * n + out_col];
321                    }
322                }
323            } else {
324                return Err(TrustformersError::tensor_op_error(
325                    "Invalid indices format",
326                    "sparse matmul",
327                ));
328            }
329        },
330        _ => {
331            return Err(TrustformersError::tensor_op_error(
332                "Unsupported sparse format for matmul",
333                "sparse matmul",
334            ));
335        },
336    }
337
338    Tensor::from_vec(result, &[m, n])
339}
340
341/// Sparse attention utilities
342pub mod sparse_attention {
343    use super::*;
344
    /// Block-sparse attention pattern
    ///
    /// Each row-block of the attention matrix attends to its diagonal
    /// neighbours plus a fixed number of pseudo-random global blocks
    /// (see `generate_mask`).
    pub struct BlockSparseAttention {
        block_size: usize,        // side length of each square attention block
        num_random_blocks: usize, // extra global blocks attended per row-block
    }
350
351    impl BlockSparseAttention {
352        /// Create a new block-sparse attention pattern
353        pub fn new(block_size: usize, num_random_blocks: usize) -> Self {
354            Self {
355                block_size,
356                num_random_blocks,
357            }
358        }
359
360        /// Generate attention mask for block-sparse pattern
361        pub fn generate_mask(&self, seq_len: usize) -> Result<SparseTensor> {
362            let num_blocks = seq_len.div_ceil(self.block_size);
363
364            let mut row_indices = Vec::new();
365            let mut col_indices = Vec::new();
366            let mut values = Vec::new();
367
368            for block_i in 0..num_blocks {
369                // Local attention (diagonal blocks)
370                for block_j in block_i.saturating_sub(1)..=(block_i + 1).min(num_blocks - 1) {
371                    self.add_block(
372                        block_i,
373                        block_j,
374                        seq_len,
375                        &mut row_indices,
376                        &mut col_indices,
377                        &mut values,
378                    );
379                }
380
381                // Random global attention (using simple deterministic pattern for now)
382                // TODO: Use proper RNG when scirs2_core Random API is clearer
383                for j in 0..self.num_random_blocks {
384                    let random_block = (block_i * 7 + j * 13) % num_blocks;
385                    self.add_block(
386                        block_i,
387                        random_block,
388                        seq_len,
389                        &mut row_indices,
390                        &mut col_indices,
391                        &mut values,
392                    );
393                }
394            }
395
396            SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
397        }
398
399        fn add_block(
400            &self,
401            block_i: usize,
402            block_j: usize,
403            seq_len: usize,
404            row_indices: &mut Vec<usize>,
405            col_indices: &mut Vec<usize>,
406            values: &mut Vec<f32>,
407        ) {
408            let row_start = block_i * self.block_size;
409            let row_end = (row_start + self.block_size).min(seq_len);
410            let col_start = block_j * self.block_size;
411            let col_end = (col_start + self.block_size).min(seq_len);
412
413            for r in row_start..row_end {
414                for c in col_start..col_end {
415                    row_indices.push(r);
416                    col_indices.push(c);
417                    values.push(1.0); // Attention mask value
418                }
419            }
420        }
421    }
422
423    /// Sliding window attention pattern
424    pub fn sliding_window_mask(seq_len: usize, window_size: usize) -> Result<SparseTensor> {
425        let mut row_indices = Vec::new();
426        let mut col_indices = Vec::new();
427        let mut values = Vec::new();
428
429        for i in 0..seq_len {
430            let start = i.saturating_sub(window_size / 2);
431            let end = (i + window_size / 2 + 1).min(seq_len);
432
433            for j in start..end {
434                row_indices.push(i);
435                col_indices.push(j);
436                values.push(1.0);
437            }
438        }
439
440        SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
441    }
442
443    /// Dilated sliding window (for longer-range dependencies)
444    pub fn dilated_window_mask(
445        seq_len: usize,
446        window_size: usize,
447        dilation: usize,
448    ) -> Result<SparseTensor> {
449        let mut row_indices = Vec::new();
450        let mut col_indices = Vec::new();
451        let mut values = Vec::new();
452
453        for i in 0..seq_len {
454            // Local window
455            let local_start = i.saturating_sub(window_size / 2);
456            let local_end = (i + window_size / 2 + 1).min(seq_len);
457
458            for j in local_start..local_end {
459                row_indices.push(i);
460                col_indices.push(j);
461                values.push(1.0);
462            }
463
464            // Dilated positions
465            for k in 1..=window_size {
466                let dilated_pos = i + k * dilation;
467                if dilated_pos < seq_len {
468                    row_indices.push(i);
469                    col_indices.push(dilated_pos);
470                    values.push(1.0);
471                }
472
473                if k * dilation <= i {
474                    let dilated_pos = i - k * dilation;
475                    row_indices.push(i);
476                    col_indices.push(dilated_pos);
477                    values.push(1.0);
478                }
479            }
480        }
481
482        SparseTensor::new_coo(vec![seq_len, seq_len], row_indices, col_indices, values)
483    }
484}
485
486/// Format conversion utilities
487pub mod conversion {
488    use super::*;
489
490    /// Convert COO to CSR format
491    pub fn coo_to_csr(sparse: &SparseTensor) -> Result<SparseTensor> {
492        if sparse.format != SparseFormat::COO {
493            return Err(TrustformersError::tensor_op_error(
494                "Input must be in COO format",
495                "COO to CSR conversion",
496            ));
497        }
498
499        if let SparseIndices::COO {
500            row_indices,
501            col_indices,
502        } = &sparse.indices
503        {
504            let num_rows = sparse.shape[0];
505
506            // Build row_ptr
507            let mut row_ptr = vec![0; num_rows + 1];
508            for &row in row_indices {
509                row_ptr[row + 1] += 1;
510            }
511
512            // Cumulative sum
513            for i in 0..num_rows {
514                row_ptr[i + 1] += row_ptr[i];
515            }
516
517            // Sort entries by row, then by column
518            let mut entries: Vec<(usize, usize, f32)> = row_indices
519                .iter()
520                .zip(col_indices.iter())
521                .zip(sparse.values.iter())
522                .map(|((&r, &c), &v)| (r, c, v))
523                .collect();
524
525            entries.sort_by_key(|&(r, c, _)| (r, c));
526
527            let sorted_col_indices: Vec<usize> = entries.iter().map(|&(_, c, _)| c).collect();
528            let sorted_values: Vec<f32> = entries.iter().map(|&(_, _, v)| v).collect();
529
530            SparseTensor::new_csr(
531                sparse.shape.clone(),
532                row_ptr,
533                sorted_col_indices,
534                sorted_values,
535            )
536        } else {
537            Err(TrustformersError::tensor_op_error(
538                "Invalid indices format",
539                "COO to CSR conversion",
540            ))
541        }
542    }
543
544    /// Convert CSR to COO format
545    pub fn csr_to_coo(sparse: &SparseTensor) -> Result<SparseTensor> {
546        if sparse.format != SparseFormat::CSR {
547            return Err(TrustformersError::tensor_op_error(
548                "Input must be in CSR format",
549                "CSR to COO conversion",
550            ));
551        }
552
553        if let SparseIndices::CSR {
554            row_ptr,
555            col_indices,
556        } = &sparse.indices
557        {
558            let mut row_indices = Vec::new();
559
560            for (row, window) in row_ptr.windows(2).enumerate() {
561                let count = window[1] - window[0];
562                row_indices.extend(vec![row; count]);
563            }
564
565            SparseTensor::new_coo(
566                sparse.shape.clone(),
567                row_indices,
568                col_indices.clone(),
569                sparse.values.clone(),
570            )
571        } else {
572            Err(TrustformersError::tensor_op_error(
573                "Invalid indices format",
574                "CSR to COO conversion",
575            ))
576        }
577    }
578}
579
580/// Pruning utilities
581pub mod pruning {
582    use super::*;
583
584    /// Magnitude-based pruning
585    pub fn magnitude_prune(tensor: &Tensor, keep_ratio: f32) -> Result<SparseTensor> {
586        let data = tensor.to_vec_f32()?;
587        let shape = tensor.shape().to_vec();
588
589        // Sort by magnitude
590        let mut indexed_data: Vec<(usize, f32)> =
591            data.iter().enumerate().map(|(i, &v)| (i, v)).collect();
592        indexed_data
593            .sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap_or(std::cmp::Ordering::Equal));
594
595        // Keep top-k
596        let num_keep = ((data.len() as f32) * keep_ratio) as usize;
597        let keep_indices: HashSet<usize> =
598            indexed_data.iter().take(num_keep).map(|&(idx, _)| idx).collect();
599
600        // Build sparse tensor
601        if shape.len() == 2 {
602            let cols = shape[1];
603            let mut row_indices = Vec::new();
604            let mut col_indices = Vec::new();
605            let mut values = Vec::new();
606
607            for idx in keep_indices {
608                let row = idx / cols;
609                let col = idx % cols;
610                row_indices.push(row);
611                col_indices.push(col);
612                values.push(data[idx]);
613            }
614
615            SparseTensor::new_coo(shape, row_indices, col_indices, values)
616        } else {
617            Err(TrustformersError::shape_error(
618                "Pruning currently supports only 2D tensors".to_string(),
619            ))
620        }
621    }
622
623    /// Gradient-based pruning (requires gradient information)
624    pub fn gradient_based_prune(
625        tensor: &Tensor,
626        gradients: &Tensor,
627        keep_ratio: f32,
628    ) -> Result<SparseTensor> {
629        let weights = tensor.to_vec_f32()?;
630        let grads = gradients.to_vec_f32()?;
631        let shape = tensor.shape().to_vec();
632
633        if weights.len() != grads.len() {
634            return Err(TrustformersError::shape_error(
635                "Weight and gradient shapes must match".to_string(),
636            ));
637        }
638
639        // Compute importance score: |weight * gradient|
640        let mut scores: Vec<(usize, f32)> = weights
641            .iter()
642            .zip(grads.iter())
643            .enumerate()
644            .map(|(i, (&w, &g))| (i, (w * g).abs()))
645            .collect();
646
647        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).expect("Partial comparison failed"));
648
649        let num_keep = ((weights.len() as f32) * keep_ratio) as usize;
650        let keep_indices: HashSet<usize> =
651            scores.iter().take(num_keep).map(|&(idx, _)| idx).collect();
652
653        // Build sparse tensor
654        if shape.len() == 2 {
655            let cols = shape[1];
656            let mut row_indices = Vec::new();
657            let mut col_indices = Vec::new();
658            let mut values = Vec::new();
659
660            for idx in keep_indices {
661                let row = idx / cols;
662                let col = idx % cols;
663                row_indices.push(row);
664                col_indices.push(col);
665                values.push(weights[idx]);
666            }
667
668            SparseTensor::new_coo(shape, row_indices, col_indices, values)
669        } else {
670            Err(TrustformersError::shape_error(
671                "Pruning currently supports only 2D tensors".to_string(),
672            ))
673        }
674    }
675}
676
#[cfg(test)]
mod tests {
    use super::*;

    // 2:4 sparsity: ratio reports 50% and exactly half of an 8x8 tensor's
    // elements survive (2 kept per window of 4, per row).
    #[test]
    fn test_nm_sparsity() -> Result<()> {
        let nm = NMSparsity::new(2, 4);
        assert_eq!(nm.sparsity_ratio(), 0.5);

        // Create a test tensor
        let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = nm.apply(&tensor)?;

        // Check that we have 50% sparsity
        let expected_nnz = 8 * 8 / 2; // 50% of 64
        assert_eq!(sparse.nnz, expected_nnz);

        Ok(())
    }

    // SpMM smoke test: COO (3x3) times dense (3x2) yields a (3x2) result.
    #[test]
    fn test_sparse_matmul() -> Result<()> {
        // Create a sparse matrix
        let sparse = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 0, 1, 2],
            vec![0, 1, 1, 2],
            vec![1.0, 2.0, 3.0, 4.0],
        )?;

        // Create a dense matrix
        let dense_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let dense = Tensor::from_vec(dense_data, &[3, 2])?;

        // Multiply
        let result = sparse_matmul(&sparse, &dense)?;

        assert_eq!(result.shape(), &[3, 2]);

        Ok(())
    }

    // Block sparsity with 2x2 tiles at keep_ratio 0.5: the result is
    // non-empty but strictly smaller than the dense element count.
    #[test]
    fn test_block_sparsity() -> Result<()> {
        let block_sparse = BlockSparsity::new(2, 2, 0.5);

        let data: Vec<f32> = (0..64).map(|i| i as f32).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = block_sparse.apply(&tensor)?;

        // Should keep approximately 50% of blocks
        assert!(sparse.nnz > 0);
        assert!(sparse.nnz < 64);

        Ok(())
    }

    // Sliding window mask: nnz is bounded by seq_len * (window_size + 1);
    // edge positions contribute fewer entries.
    #[test]
    fn test_sliding_window_mask() -> Result<()> {
        let mask = sparse_attention::sliding_window_mask(100, 10)?;

        // Each position attends to window_size/2 positions on each side plus itself
        // So each position has approximately window_size + 1 elements
        // The total is bounded by seq_len * (window_size + 1) for middle positions
        // Edge positions have fewer elements due to boundary effects
        assert!(mask.nnz <= 100 * 11);
        assert!(mask.nnz > 0);

        Ok(())
    }

    // Magnitude pruning at keep_ratio 0.25 on 64 elements keeps exactly 16.
    #[test]
    fn test_magnitude_pruning() -> Result<()> {
        let data: Vec<f32> = (0..64).map(|i| (i as f32) - 32.0).collect();
        let tensor = Tensor::from_vec(data, &[8, 8])?;

        let sparse = pruning::magnitude_prune(&tensor, 0.25)?;

        // Should keep 25% of elements
        assert_eq!(sparse.nnz, 16);

        Ok(())
    }

    // Round-trip COO -> CSR -> COO preserves format tags and nnz.
    #[test]
    fn test_coo_to_csr_conversion() -> Result<()> {
        let coo = SparseTensor::new_coo(
            vec![3, 3],
            vec![0, 0, 1, 2],
            vec![0, 1, 1, 2],
            vec![1.0, 2.0, 3.0, 4.0],
        )?;

        let csr = conversion::coo_to_csr(&coo)?;

        assert_eq!(csr.format, SparseFormat::CSR);
        assert_eq!(csr.nnz, 4);

        // Convert back
        let coo2 = conversion::csr_to_coo(&csr)?;
        assert_eq!(coo2.format, SparseFormat::COO);
        assert_eq!(coo2.nnz, 4);

        Ok(())
    }
}