// torsh_backend/sparse_ops.rs
//! Comprehensive sparse operations support for ToRSh backends
//!
//! This module provides efficient sparse matrix and tensor operations, including
//! different storage formats and optimized kernels for sparse computations.
//!
//! Supported formats:
//! - COO (Coordinate format)
//! - CSR (Compressed Sparse Row)
//! - CSC (Compressed Sparse Column)
//! - BSR (Block Sparse Row)
//! - Hybrid formats for mixed workloads
13use crate::{BackendResult, Device};
14use std::collections::HashMap;
15use torsh_core::error::TorshError;
16
17#[cfg(not(feature = "std"))]
18use alloc::{string::String, vec::Vec};
19
/// Sparse matrix storage format.
///
/// Determines how [`SparseMatrix`] interprets its `row_indices` /
/// `col_indices` vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseFormat {
    /// Coordinate format (COO) - stores (row, col, value) triplets
    Coo,
    /// Compressed Sparse Row (CSR) format
    Csr,
    /// Compressed Sparse Column (CSC) format
    Csc,
    /// Block Sparse Row (BSR) format - for structured sparsity
    Bsr,
    /// Dense format (for comparison and conversion)
    Dense,
}
34
/// Sparse matrix structure.
///
/// The two index vectors are reused across formats, so their meaning depends
/// on `format`:
/// - COO: `row_indices[i]` / `col_indices[i]` give the coordinate of `values[i]`
/// - CSR: `row_indices` is the row-pointer array (length `rows + 1`),
///   `col_indices` holds one column index per stored value
/// - CSC: `col_indices` is the column-pointer array (length `cols + 1`),
///   `row_indices` holds one row index per stored value
/// - BSR: `row_indices` holds block-row pointers, `col_indices` block-column
///   indices, and `values` the flattened dense block contents
#[derive(Debug, Clone)]
pub struct SparseMatrix<T> {
    /// Storage format
    pub format: SparseFormat,
    /// Number of rows
    pub rows: usize,
    /// Number of columns
    pub cols: usize,
    /// Number of non-zero elements (for BSR: number of stored blocks)
    pub nnz: usize,
    /// Values of non-zero elements
    pub values: Vec<T>,
    /// Row indices (format-dependent meaning; see type docs)
    pub row_indices: Vec<usize>,
    /// Column indices (format-dependent meaning; see type docs)
    pub col_indices: Vec<usize>,
    /// Block size (rows, cols) for BSR format; `None` for other formats
    pub block_size: Option<(usize, usize)>,
}
55
56impl<T> Default for SparseMatrix<T> {
57    fn default() -> Self {
58        Self {
59            format: SparseFormat::Coo,
60            rows: 0,
61            cols: 0,
62            nnz: 0,
63            values: Vec::new(),
64            row_indices: Vec::new(),
65            col_indices: Vec::new(),
66            block_size: None,
67        }
68    }
69}
70
71impl<T: Clone + Default + PartialEq> SparseMatrix<T> {
72    /// Create a new sparse matrix in COO format
73    pub fn new_coo(rows: usize, cols: usize) -> Self {
74        Self {
75            format: SparseFormat::Coo,
76            rows,
77            cols,
78            nnz: 0,
79            values: Vec::new(),
80            row_indices: Vec::new(),
81            col_indices: Vec::new(),
82            block_size: None,
83        }
84    }
85
86    /// Create a new sparse matrix in CSR format
87    pub fn new_csr(rows: usize, cols: usize) -> Self {
88        Self {
89            format: SparseFormat::Csr,
90            rows,
91            cols,
92            nnz: 0,
93            values: Vec::new(),
94            row_indices: Vec::with_capacity(rows + 1), // row_ptr array
95            col_indices: Vec::new(),
96            block_size: None,
97        }
98    }
99
100    /// Create a new sparse matrix in CSC format
101    pub fn new_csc(rows: usize, cols: usize) -> Self {
102        Self {
103            format: SparseFormat::Csc,
104            rows,
105            cols,
106            nnz: 0,
107            values: Vec::new(),
108            row_indices: Vec::new(),
109            col_indices: Vec::with_capacity(cols + 1), // col_ptr array
110            block_size: None,
111        }
112    }
113
114    /// Insert a value at (row, col) for COO format
115    pub fn insert_coo(&mut self, row: usize, col: usize, value: T) -> BackendResult<()> {
116        if self.format != SparseFormat::Coo {
117            return Err(TorshError::ComputeError(
118                "Matrix is not in COO format".to_string(),
119            ));
120        }
121
122        if row >= self.rows || col >= self.cols {
123            return Err(TorshError::ComputeError("Index out of bounds".to_string()));
124        }
125
126        // For simplicity, we always append (real implementation would handle duplicates)
127        self.row_indices.push(row);
128        self.col_indices.push(col);
129        self.values.push(value);
130        self.nnz += 1;
131
132        Ok(())
133    }
134
135    /// Convert COO to CSR format
136    pub fn to_csr(&self) -> BackendResult<SparseMatrix<T>> {
137        if self.format != SparseFormat::Coo {
138            return Err(TorshError::ComputeError(
139                "Source matrix must be in COO format".to_string(),
140            ));
141        }
142
143        let mut csr = SparseMatrix::new_csr(self.rows, self.cols);
144        csr.nnz = self.nnz;
145
146        if self.nnz == 0 {
147            // Initialize empty row_ptr array
148            csr.row_indices = vec![0; self.rows + 1];
149            return Ok(csr);
150        }
151
152        // Count non-zeros per row
153        let mut row_counts = vec![0; self.rows];
154        for &row in &self.row_indices {
155            row_counts[row] += 1;
156        }
157
158        // Build row_ptr array (cumulative sum)
159        csr.row_indices.push(0);
160        for count in row_counts {
161            let last = *csr
162                .row_indices
163                .last()
164                .expect("row_indices should not be empty after initial push");
165            csr.row_indices.push(last + count);
166        }
167
168        // Sort entries by row, then by column
169        let mut triplets: Vec<(usize, usize, T)> = self
170            .row_indices
171            .iter()
172            .zip(self.col_indices.iter())
173            .zip(self.values.iter())
174            .map(|((&r, &c), v)| (r, c, v.clone()))
175            .collect();
176
177        triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
178
179        // Fill values and col_indices
180        csr.values.reserve(self.nnz);
181        csr.col_indices.reserve(self.nnz);
182
183        for (_, col, value) in triplets {
184            csr.col_indices.push(col);
185            csr.values.push(value);
186        }
187
188        Ok(csr)
189    }
190
191    /// Convert COO to CSC format
192    pub fn to_csc(&self) -> BackendResult<SparseMatrix<T>> {
193        if self.format != SparseFormat::Coo {
194            return Err(TorshError::ComputeError(
195                "Source matrix must be in COO format".to_string(),
196            ));
197        }
198
199        let mut csc = SparseMatrix::new_csc(self.rows, self.cols);
200        csc.nnz = self.nnz;
201
202        if self.nnz == 0 {
203            // Initialize empty col_ptr array
204            csc.col_indices = vec![0; self.cols + 1];
205            return Ok(csc);
206        }
207
208        // Count non-zeros per column
209        let mut col_counts = vec![0; self.cols];
210        for &col in &self.col_indices {
211            col_counts[col] += 1;
212        }
213
214        // Build col_ptr array (cumulative sum)
215        csc.col_indices.push(0);
216        for count in col_counts {
217            let last = *csc
218                .col_indices
219                .last()
220                .expect("col_indices should not be empty after initial push");
221            csc.col_indices.push(last + count);
222        }
223
224        // Sort entries by column, then by row
225        let mut triplets: Vec<(usize, usize, T)> = self
226            .row_indices
227            .iter()
228            .zip(self.col_indices.iter())
229            .zip(self.values.iter())
230            .map(|((&r, &c), v)| (r, c, v.clone()))
231            .collect();
232
233        triplets.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0)));
234
235        // Fill values and row_indices
236        csc.values.reserve(self.nnz);
237        csc.row_indices.reserve(self.nnz);
238
239        for (row, _, value) in triplets {
240            csc.row_indices.push(row);
241            csc.values.push(value);
242        }
243
244        Ok(csc)
245    }
246
247    /// Get sparsity ratio (percentage of non-zero elements)
248    pub fn sparsity_ratio(&self) -> f64 {
249        if self.rows == 0 || self.cols == 0 {
250            return 0.0;
251        }
252        self.nnz as f64 / (self.rows * self.cols) as f64
253    }
254
255    /// Check if the matrix is effectively sparse (< 50% non-zero)
256    pub fn is_sparse(&self) -> bool {
257        self.sparsity_ratio() < 0.5
258    }
259}
260
/// Sparse operations trait for different backends.
///
/// All methods return `BackendResult` so implementations can surface
/// dimension and format errors uniformly.
pub trait SparseOps<T> {
    /// Sparse matrix-vector multiplication: y = A * x.
    /// `x` has `matrix.cols` elements; `y` has `matrix.rows` elements.
    fn spmv(&self, matrix: &SparseMatrix<T>, x: &[T], y: &mut [T]) -> BackendResult<()>;

    /// Sparse matrix-matrix multiplication: C = A * B
    fn spmm(&self, a: &SparseMatrix<T>, b: &SparseMatrix<T>) -> BackendResult<SparseMatrix<T>>;

    /// Sparse matrix addition: C = A + B (shapes must match)
    fn sparse_add(
        &self,
        a: &SparseMatrix<T>,
        b: &SparseMatrix<T>,
    ) -> BackendResult<SparseMatrix<T>>;

    /// Convert sparse matrix to a dense, row-major buffer of `rows * cols` elements
    fn to_dense(&self, matrix: &SparseMatrix<T>) -> BackendResult<Vec<T>>;

    /// Create sparse matrix from a dense, row-major slice; entries with
    /// magnitude above `threshold` are kept
    fn from_dense(
        &self,
        dense: &[T],
        rows: usize,
        cols: usize,
        threshold: T,
    ) -> BackendResult<SparseMatrix<T>>;

    /// Transpose sparse matrix
    fn transpose(&self, matrix: &SparseMatrix<T>) -> BackendResult<SparseMatrix<T>>;
}
291
/// Default (reference) implementation of [`SparseOps`].
#[derive(Debug)]
pub struct DefaultSparseOps {
    /// Device for operations
    // NOTE(review): stored but not read anywhere in this module — presumably
    // reserved for device-specific dispatch; confirm before removing.
    #[allow(dead_code)]
    device: Device,
    /// Optimization hints
    optimization_hints: SparseOptimizationHints,
}
301
302impl DefaultSparseOps {
303    /// Create new sparse operations instance
304    pub fn new(device: Device) -> Self {
305        Self {
306            device,
307            optimization_hints: SparseOptimizationHints::default(),
308        }
309    }
310
311    /// Set optimization hints
312    pub fn with_hints(mut self, hints: SparseOptimizationHints) -> Self {
313        self.optimization_hints = hints;
314        self
315    }
316}
317
318impl SparseOps<f32> for DefaultSparseOps {
319    fn spmv(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
320        if x.len() != matrix.cols || y.len() != matrix.rows {
321            return Err(TorshError::ComputeError("Dimension mismatch".to_string()));
322        }
323
324        // Initialize output to zero
325        y.fill(0.0);
326
327        match matrix.format {
328            SparseFormat::Csr => self.spmv_csr(matrix, x, y),
329            SparseFormat::Coo => self.spmv_coo(matrix, x, y),
330            SparseFormat::Csc => self.spmv_csc(matrix, x, y),
331            _ => Err(TorshError::ComputeError(
332                "Unsupported sparse format for SpMV".to_string(),
333            )),
334        }
335    }
336
337    fn spmm(
338        &self,
339        a: &SparseMatrix<f32>,
340        b: &SparseMatrix<f32>,
341    ) -> BackendResult<SparseMatrix<f32>> {
342        if a.cols != b.rows {
343            return Err(TorshError::ComputeError(
344                "Matrix dimensions incompatible for multiplication".to_string(),
345            ));
346        }
347
348        // For simplicity, convert both to CSR format
349        let a_csr = if a.format == SparseFormat::Csr {
350            a.clone()
351        } else {
352            a.to_csr()?
353        };
354
355        let b_csr = if b.format == SparseFormat::Csr {
356            b.clone()
357        } else {
358            b.to_csr()?
359        };
360
361        self.spmm_csr_csr(&a_csr, &b_csr)
362    }
363
364    fn sparse_add(
365        &self,
366        a: &SparseMatrix<f32>,
367        b: &SparseMatrix<f32>,
368    ) -> BackendResult<SparseMatrix<f32>> {
369        if a.rows != b.rows || a.cols != b.cols {
370            return Err(TorshError::ComputeError(
371                "Matrix dimensions must match for addition".to_string(),
372            ));
373        }
374
375        // Convert both to COO format for easier addition
376        let a_coo = if a.format == SparseFormat::Coo {
377            a.clone()
378        } else {
379            // For now, only support COO to CSR conversion
380            return Err(TorshError::ComputeError(
381                "Sparse addition requires COO format".to_string(),
382            ));
383        };
384
385        let b_coo = if b.format == SparseFormat::Coo {
386            b.clone()
387        } else {
388            return Err(TorshError::ComputeError(
389                "Sparse addition requires COO format".to_string(),
390            ));
391        };
392
393        self.sparse_add_coo(&a_coo, &b_coo)
394    }
395
396    fn to_dense(&self, matrix: &SparseMatrix<f32>) -> BackendResult<Vec<f32>> {
397        let mut dense = vec![0.0; matrix.rows * matrix.cols];
398
399        match matrix.format {
400            SparseFormat::Coo => {
401                for i in 0..matrix.nnz {
402                    let row = matrix.row_indices[i];
403                    let col = matrix.col_indices[i];
404                    let val = matrix.values[i];
405                    dense[row * matrix.cols + col] = val;
406                }
407            }
408            SparseFormat::Csr => {
409                for row in 0..matrix.rows {
410                    let start = matrix.row_indices[row];
411                    let end = matrix.row_indices[row + 1];
412                    for idx in start..end {
413                        let col = matrix.col_indices[idx];
414                        let val = matrix.values[idx];
415                        dense[row * matrix.cols + col] = val;
416                    }
417                }
418            }
419            SparseFormat::Csc => {
420                for col in 0..matrix.cols {
421                    let start = matrix.col_indices[col];
422                    let end = matrix.col_indices[col + 1];
423                    for idx in start..end {
424                        let row = matrix.row_indices[idx];
425                        let val = matrix.values[idx];
426                        dense[row * matrix.cols + col] = val;
427                    }
428                }
429            }
430            _ => {
431                return Err(TorshError::ComputeError(
432                    "Unsupported format for dense conversion".to_string(),
433                ))
434            }
435        }
436
437        Ok(dense)
438    }
439
440    fn from_dense(
441        &self,
442        dense: &[f32],
443        rows: usize,
444        cols: usize,
445        threshold: f32,
446    ) -> BackendResult<SparseMatrix<f32>> {
447        if dense.len() != rows * cols {
448            return Err(TorshError::ComputeError(
449                "Dense array size doesn't match dimensions".to_string(),
450            ));
451        }
452
453        let mut sparse = SparseMatrix::new_coo(rows, cols);
454
455        for row in 0..rows {
456            for col in 0..cols {
457                let val = dense[row * cols + col];
458                if val.abs() > threshold {
459                    sparse.insert_coo(row, col, val)?;
460                }
461            }
462        }
463
464        Ok(sparse)
465    }
466
467    fn transpose(&self, matrix: &SparseMatrix<f32>) -> BackendResult<SparseMatrix<f32>> {
468        match matrix.format {
469            SparseFormat::Coo => {
470                let mut transposed = SparseMatrix::new_coo(matrix.cols, matrix.rows);
471                transposed.nnz = matrix.nnz;
472
473                // Swap row and column indices
474                transposed.row_indices = matrix.col_indices.clone();
475                transposed.col_indices = matrix.row_indices.clone();
476                transposed.values = matrix.values.clone();
477
478                Ok(transposed)
479            }
480            SparseFormat::Csr => {
481                // CSR transpose becomes CSC
482                let mut transposed = SparseMatrix::new_csc(matrix.cols, matrix.rows);
483                transposed.nnz = matrix.nnz;
484                transposed.values = matrix.values.clone();
485                transposed.row_indices = matrix.col_indices.clone();
486                transposed.col_indices = matrix.row_indices.clone();
487                Ok(transposed)
488            }
489            SparseFormat::Csc => {
490                // CSC transpose becomes CSR
491                let mut transposed = SparseMatrix::new_csr(matrix.cols, matrix.rows);
492                transposed.nnz = matrix.nnz;
493                transposed.values = matrix.values.clone();
494                transposed.row_indices = matrix.col_indices.clone();
495                transposed.col_indices = matrix.row_indices.clone();
496                Ok(transposed)
497            }
498            _ => Err(TorshError::ComputeError(
499                "Unsupported format for transpose".to_string(),
500            )),
501        }
502    }
503}
504
505impl DefaultSparseOps {
506    /// CSR format SpMV implementation
507    fn spmv_csr(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
508        for row in 0..matrix.rows {
509            let start = matrix.row_indices[row];
510            let end = matrix.row_indices[row + 1];
511            let mut sum = 0.0;
512
513            for idx in start..end {
514                let col = matrix.col_indices[idx];
515                let val = matrix.values[idx];
516                sum += val * x[col];
517            }
518
519            y[row] = sum;
520        }
521        Ok(())
522    }
523
524    /// COO format SpMV implementation
525    fn spmv_coo(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
526        for i in 0..matrix.nnz {
527            let row = matrix.row_indices[i];
528            let col = matrix.col_indices[i];
529            let val = matrix.values[i];
530            y[row] += val * x[col];
531        }
532        Ok(())
533    }
534
535    /// CSC format SpMV implementation
536    fn spmv_csc(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
537        for col in 0..matrix.cols {
538            let start = matrix.col_indices[col];
539            let end = matrix.col_indices[col + 1];
540            let x_val = x[col];
541
542            for idx in start..end {
543                let row = matrix.row_indices[idx];
544                let val = matrix.values[idx];
545                y[row] += val * x_val;
546            }
547        }
548        Ok(())
549    }
550
551    /// CSR x CSR matrix multiplication
552    fn spmm_csr_csr(
553        &self,
554        a: &SparseMatrix<f32>,
555        b: &SparseMatrix<f32>,
556    ) -> BackendResult<SparseMatrix<f32>> {
557        // This is a simplified implementation
558        // Real implementation would use more sophisticated algorithms
559        let mut result = SparseMatrix::new_coo(a.rows, b.cols);
560
561        for row_a in 0..a.rows {
562            let start_a = a.row_indices[row_a];
563            let end_a = a.row_indices[row_a + 1];
564
565            for idx_a in start_a..end_a {
566                let col_a = a.col_indices[idx_a];
567                let val_a = a.values[idx_a];
568
569                // col_a is the row in matrix B
570                let start_b = b.row_indices[col_a];
571                let end_b = b.row_indices[col_a + 1];
572
573                for idx_b in start_b..end_b {
574                    let col_b = b.col_indices[idx_b];
575                    let val_b = b.values[idx_b];
576
577                    let product = val_a * val_b;
578                    result.insert_coo(row_a, col_b, product)?;
579                }
580            }
581        }
582
583        Ok(result)
584    }
585
586    /// COO format sparse addition
587    fn sparse_add_coo(
588        &self,
589        a: &SparseMatrix<f32>,
590        b: &SparseMatrix<f32>,
591    ) -> BackendResult<SparseMatrix<f32>> {
592        let mut result = SparseMatrix::new_coo(a.rows, a.cols);
593
594        // Use a hashmap to combine duplicate entries
595        let mut entries: HashMap<(usize, usize), f32> = HashMap::new();
596
597        // Add entries from matrix A
598        for i in 0..a.nnz {
599            let key = (a.row_indices[i], a.col_indices[i]);
600            *entries.entry(key).or_insert(0.0) += a.values[i];
601        }
602
603        // Add entries from matrix B
604        for i in 0..b.nnz {
605            let key = (b.row_indices[i], b.col_indices[i]);
606            *entries.entry(key).or_insert(0.0) += b.values[i];
607        }
608
609        // Convert back to COO format
610        for ((row, col), value) in entries {
611            if value != 0.0 {
612                result.insert_coo(row, col, value)?;
613            }
614        }
615
616        Ok(result)
617    }
618}
619
620impl<T: Clone + Default + PartialEq> SparseMatrix<T> {
621    /// Create BSR (Block Sparse Row) matrix from COO
622    pub fn to_bsr(&self, block_size: (usize, usize)) -> BackendResult<SparseMatrix<T>> {
623        if self.format != SparseFormat::Coo {
624            return Err(TorshError::ComputeError(
625                "Source matrix must be in COO format".to_string(),
626            ));
627        }
628
629        let (block_rows, block_cols) = block_size;
630        if block_rows == 0 || block_cols == 0 {
631            return Err(TorshError::ComputeError(
632                "Block size must be positive".to_string(),
633            ));
634        }
635
636        // Calculate number of block rows and columns
637        let num_block_rows = (self.rows + block_rows - 1) / block_rows;
638        let _num_block_cols = (self.cols + block_cols - 1) / block_cols;
639
640        let mut bsr = SparseMatrix {
641            format: SparseFormat::Bsr,
642            rows: self.rows,
643            cols: self.cols,
644            nnz: 0,
645            values: Vec::new(),
646            row_indices: vec![0; num_block_rows + 1], // block row pointers
647            col_indices: Vec::new(),                  // block column indices
648            block_size: Some(block_size),
649        };
650
651        // Group non-zeros by blocks
652        let mut blocks: HashMap<(usize, usize), Vec<T>> = HashMap::new();
653
654        for i in 0..self.nnz {
655            let row = self.row_indices[i];
656            let col = self.col_indices[i];
657            let val = self.values[i].clone();
658
659            let block_row = row / block_rows;
660            let block_col = col / block_cols;
661            let in_block_row = row % block_rows;
662            let in_block_col = col % block_cols;
663
664            let block_entry = blocks
665                .entry((block_row, block_col))
666                .or_insert_with(|| vec![T::default(); block_rows * block_cols]);
667            block_entry[in_block_row * block_cols + in_block_col] = val;
668        }
669
670        // Convert to BSR format
671        let mut sorted_blocks: Vec<_> = blocks.into_iter().collect();
672        sorted_blocks.sort_by_key(|&((br, bc), _)| (br, bc));
673
674        let mut current_block_row = 0;
675        for ((block_row, block_col), block_values) in sorted_blocks {
676            // Update row pointers
677            while current_block_row < block_row {
678                current_block_row += 1;
679                bsr.row_indices[current_block_row] = bsr.col_indices.len();
680            }
681
682            // Add block
683            bsr.col_indices.push(block_col);
684            bsr.values.extend(block_values);
685            bsr.nnz += 1; // Number of blocks, not individual elements
686        }
687
688        // Fill remaining row pointers
689        let final_ptr = bsr.col_indices.len();
690        for i in (current_block_row + 1)..=num_block_rows {
691            bsr.row_indices[i] = final_ptr;
692        }
693
694        Ok(bsr)
695    }
696
697    /// Optimize matrix structure by removing explicit zeros and sorting
698    pub fn optimize(&mut self) -> BackendResult<()> {
699        match self.format {
700            SparseFormat::Coo => {
701                // Remove explicit zeros and sort by (row, col)
702                let mut triplets: Vec<(usize, usize, T)> = (0..self.nnz)
703                    .filter_map(|i| {
704                        let val = &self.values[i];
705                        if *val != T::default() {
706                            Some((self.row_indices[i], self.col_indices[i], val.clone()))
707                        } else {
708                            None
709                        }
710                    })
711                    .collect();
712
713                triplets.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)));
714
715                // Rebuild arrays
716                self.nnz = triplets.len();
717                self.row_indices.clear();
718                self.col_indices.clear();
719                self.values.clear();
720
721                for (row, col, val) in triplets {
722                    self.row_indices.push(row);
723                    self.col_indices.push(col);
724                    self.values.push(val);
725                }
726            }
727            SparseFormat::Csr | SparseFormat::Csc => {
728                // Remove explicit zeros while maintaining format
729                let mut new_values = Vec::new();
730                let mut new_col_indices = Vec::new();
731                let mut new_row_pointers = vec![0];
732
733                let num_rows = if self.format == SparseFormat::Csr {
734                    self.rows
735                } else {
736                    self.cols
737                };
738
739                for row in 0..num_rows {
740                    let start = self.row_indices[row];
741                    let end = self.row_indices[row + 1];
742
743                    for idx in start..end {
744                        if self.values[idx] != T::default() {
745                            new_values.push(self.values[idx].clone());
746                            new_col_indices.push(self.col_indices[idx]);
747                        }
748                    }
749                    new_row_pointers.push(new_values.len());
750                }
751
752                self.values = new_values;
753                self.col_indices = new_col_indices;
754                self.row_indices = new_row_pointers;
755                self.nnz = self.values.len();
756            }
757            _ => {
758                return Err(TorshError::ComputeError(
759                    "Optimization not supported for this format".to_string(),
760                ))
761            }
762        }
763
764        Ok(())
765    }
766
767    /// Get matrix statistics for performance analysis
768    pub fn statistics(&self) -> SparseMatrixStatistics {
769        let mut max_row_nnz = 0;
770        let mut min_row_nnz = usize::MAX;
771        let mut row_nnz_variance = 0.0;
772
773        match self.format {
774            SparseFormat::Csr => {
775                let mut row_counts = Vec::new();
776                for row in 0..self.rows {
777                    let count = self.row_indices[row + 1] - self.row_indices[row];
778                    row_counts.push(count);
779                    max_row_nnz = max_row_nnz.max(count);
780                    min_row_nnz = min_row_nnz.min(count);
781                }
782
783                let mean = row_counts.iter().sum::<usize>() as f64 / row_counts.len() as f64;
784                row_nnz_variance = row_counts
785                    .iter()
786                    .map(|&x| (x as f64 - mean).powi(2))
787                    .sum::<f64>()
788                    / row_counts.len() as f64;
789            }
790            SparseFormat::Coo => {
791                let mut row_counts = vec![0; self.rows];
792                for &row in &self.row_indices {
793                    row_counts[row] += 1;
794                }
795                max_row_nnz = *row_counts.iter().max().unwrap_or(&0);
796                min_row_nnz = *row_counts.iter().min().unwrap_or(&0);
797
798                let mean = self.nnz as f64 / self.rows as f64;
799                row_nnz_variance = row_counts
800                    .iter()
801                    .map(|&x| (x as f64 - mean).powi(2))
802                    .sum::<f64>()
803                    / self.rows as f64;
804            }
805            _ => {
806                // For other formats, provide basic stats
807                min_row_nnz = if self.nnz == 0 { 0 } else { 1 };
808            }
809        }
810
811        SparseMatrixStatistics {
812            format: self.format,
813            rows: self.rows,
814            cols: self.cols,
815            nnz: self.nnz,
816            sparsity_ratio: self.sparsity_ratio(),
817            max_row_nnz,
818            min_row_nnz,
819            row_nnz_variance,
820            memory_usage: self.estimated_memory_usage(),
821        }
822    }
823
824    /// Estimate memory usage in bytes
825    fn estimated_memory_usage(&self) -> usize {
826        SparseFormatConverter::estimate_memory_usage(self.rows, self.cols, self.nnz, self.format)
827    }
828}
829
/// Optimization hints for sparse operations.
///
/// Currently stored on [`DefaultSparseOps`]; the reference kernels do not
/// yet read these fields — they are advisory for future backends.
#[derive(Debug, Clone)]
pub struct SparseOptimizationHints {
    /// Prefer memory efficiency over speed
    pub memory_efficient: bool,
    /// Use parallel processing when available
    pub use_parallel: bool,
    /// Expected sparsity level (0.0 to 1.0)
    pub expected_sparsity: f64,
    /// Block size for BSR format operations
    pub block_size: Option<(usize, usize)>,
    /// Cache block size for tiled operations
    pub cache_block_size: usize,
}
844
845impl Default for SparseOptimizationHints {
846    fn default() -> Self {
847        Self {
848            memory_efficient: true,
849            use_parallel: true,
850            expected_sparsity: 0.1, // 10% non-zero by default
851            block_size: None,
852            cache_block_size: 64,
853        }
854    }
855}
856
/// Sparse format conversion utilities.
///
/// Stateless namespace for format-selection heuristics and memory estimates.
pub struct SparseFormatConverter;
859
860impl SparseFormatConverter {
861    /// Automatically choose the best format based on matrix properties
862    pub fn choose_optimal_format<T>(
863        _matrix: &SparseMatrix<T>,
864        operation: SparseOperation,
865    ) -> SparseFormat {
866        match operation {
867            SparseOperation::SpMV => {
868                // CSR is typically best for SpMV
869                SparseFormat::Csr
870            }
871            SparseOperation::SpMM => {
872                // CSR x CSC is often efficient for SpMM
873                SparseFormat::Csr
874            }
875            SparseOperation::Addition => {
876                // COO is easiest for addition
877                SparseFormat::Coo
878            }
879            SparseOperation::Transpose => {
880                // COO is format-agnostic for transpose
881                SparseFormat::Coo
882            }
883            SparseOperation::Iterative => {
884                // CSR is good for iterative methods
885                SparseFormat::Csr
886            }
887        }
888    }
889
890    /// Get memory usage estimate for different formats
891    pub fn estimate_memory_usage(
892        rows: usize,
893        cols: usize,
894        nnz: usize,
895        format: SparseFormat,
896    ) -> usize {
897        match format {
898            SparseFormat::Coo => {
899                // (row_idx, col_idx, value) for each non-zero
900                nnz * (std::mem::size_of::<usize>() * 2 + std::mem::size_of::<f32>())
901            }
902            SparseFormat::Csr => {
903                // row_ptr array + col_indices + values
904                (rows + 1) * std::mem::size_of::<usize>()
905                    + nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
906            }
907            SparseFormat::Csc => {
908                // col_ptr array + row_indices + values
909                (cols + 1) * std::mem::size_of::<usize>()
910                    + nnz * (std::mem::size_of::<usize>() + std::mem::size_of::<f32>())
911            }
912            SparseFormat::Dense => rows * cols * std::mem::size_of::<f32>(),
913            _ => nnz * std::mem::size_of::<f32>() * 3, // Conservative estimate
914        }
915    }
916}
917
/// Types of sparse operations for format optimization
///
/// Used by [`SparseFormatConverter::choose_optimal_format`] to pick the
/// storage format best suited to the upcoming workload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SparseOperation {
    /// Sparse matrix-vector multiplication (y = A * x)
    SpMV,
    /// Sparse matrix-matrix multiplication
    SpMM,
    /// Matrix addition
    Addition,
    /// Matrix transpose
    Transpose,
    /// Iterative solver operations (repeated SpMV-style products)
    Iterative,
}
932
/// Statistics for sparse matrix analysis and optimization
///
/// A plain snapshot of structural properties, used to drive format
/// selection and load-balancing heuristics.
#[derive(Debug, Clone)]
pub struct SparseMatrixStatistics {
    /// Matrix storage format
    pub format: SparseFormat,
    /// Number of rows
    pub rows: usize,
    /// Number of columns
    pub cols: usize,
    /// Number of non-zero elements
    pub nnz: usize,
    /// Sparsity ratio in [0.0, 1.0] (presumably nnz / (rows * cols);
    /// computed elsewhere — confirm against `SparseMatrix::statistics`)
    pub sparsity_ratio: f64,
    /// Maximum non-zeros in any single row
    pub max_row_nnz: usize,
    /// Minimum non-zeros in any single row
    pub min_row_nnz: usize,
    /// Variance of per-row non-zero counts (high value => unbalanced rows)
    pub row_nnz_variance: f64,
    /// Estimated memory usage in bytes
    pub memory_usage: usize,
}
955
956impl SparseMatrixStatistics {
957    /// Check if matrix structure is well-balanced
958    pub fn is_well_balanced(&self) -> bool {
959        if self.rows == 0 || self.nnz == 0 {
960            return true;
961        }
962
963        let avg_nnz_per_row = self.nnz as f64 / self.rows as f64;
964        let balance_ratio = self.max_row_nnz as f64 / avg_nnz_per_row.max(1.0);
965
966        // Consider well-balanced if max row doesn't have more than 3x average
967        balance_ratio < 3.0
968    }
969
970    /// Get recommended operations based on matrix characteristics
971    pub fn recommended_operations(&self) -> Vec<&'static str> {
972        let mut recommendations = Vec::new();
973
974        if self.sparsity_ratio < 0.1 {
975            recommendations.push("Very sparse - excellent for sparse algorithms");
976        } else if self.sparsity_ratio > 0.5 {
977            recommendations.push("Dense - consider dense algorithms");
978        }
979
980        if !self.is_well_balanced() {
981            recommendations.push("Unbalanced structure - consider load balancing");
982        }
983
984        match self.format {
985            SparseFormat::Coo => {
986                recommendations.push("COO format - good for construction and element access")
987            }
988            SparseFormat::Csr => {
989                recommendations.push("CSR format - optimal for SpMV and most algorithms")
990            }
991            SparseFormat::Csc => recommendations.push("CSC format - good for transpose operations"),
992            SparseFormat::Bsr => {
993                recommendations.push("BSR format - optimal for block-structured sparsity")
994            }
995            SparseFormat::Dense => recommendations.push("Dense format - use dense linear algebra"),
996        }
997
998        recommendations
999    }
1000}
1001
/// Hardware-accelerated sparse operations (backend-specific implementations)
///
/// Extends [`SparseOps`] with batched, fused, and iterative-solver entry
/// points that backends may map onto SIMD units, GPUs, or other
/// specialized hardware.
pub trait HardwareSparseOps<T>: SparseOps<T> {
    /// Get hardware acceleration capabilities of this backend
    fn acceleration_capabilities(&self) -> SparseAccelerationCapabilities;

    /// Batched sparse matrix-vector multiplication.
    ///
    /// `matrices`, `vectors`, and `results` are parallel slices: entry `i`
    /// of each belongs to the same problem. (NOTE(review): presumably all
    /// three must have equal length — implementations should verify.)
    fn batch_spmv(
        &self,
        matrices: &[&SparseMatrix<T>],
        vectors: &[&[T]],
        results: &mut [&mut [T]],
    ) -> BackendResult<()>;

    /// Fused sparse operations (e.g., SpMV + vector operations).
    ///
    /// NOTE(review): the BLAS-style signature suggests
    /// `result = alpha * (matrix * x) + beta * y` — confirm against
    /// concrete implementations before relying on the exact formula.
    fn fused_spmv_add(
        &self,
        matrix: &SparseMatrix<T>,
        x: &[T],
        y: &[T],
        result: &mut [T],
        alpha: T,
        beta: T,
    ) -> BackendResult<()>;

    /// Iteratively solve `matrix * x = b` starting from the initial guess
    /// `x0`, stopping at `tolerance` or after `max_iterations`.
    /// Convergence details are reported in [`SolverResult`].
    fn iterative_solve(
        &self,
        matrix: &SparseMatrix<T>,
        b: &[T],
        x0: &[T],
        method: IterativeMethod,
        tolerance: f64,
        max_iterations: usize,
    ) -> BackendResult<SolverResult<T>>;
}
1037
/// Hardware acceleration capabilities for sparse operations
#[derive(Debug, Clone)]
pub struct SparseAccelerationCapabilities {
    /// SIMD vector width in lanes (e.g. 8 for AVX-256, 4 for NEON-128)
    pub simd_width: usize,
    /// GPU acceleration available
    pub gpu_acceleration: bool,
    /// Specialized sparse hardware available (e.g. tensor cores)
    pub specialized_hardware: bool,
    /// Multi-threading support
    pub parallel_execution: bool,
    /// Estimated memory bandwidth (GB/s)
    pub memory_bandwidth: f32,
}
1052
/// Iterative solver methods for sparse linear systems
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterativeMethod {
    /// Conjugate Gradient (requires symmetric positive-definite matrices)
    ConjugateGradient,
    /// Biconjugate Gradient Stabilized (for general, non-symmetric systems)
    BiCGStab,
    /// Generalized Minimal Residual (for general, non-symmetric systems)
    GMRES,
    /// Jacobi iteration (simple stationary method)
    Jacobi,
    /// Gauss-Seidel iteration (stationary method; uses updated values in-place)
    GaussSeidel,
}
1067
/// Result from iterative solver
#[derive(Debug, Clone)]
pub struct SolverResult<T> {
    /// Solution vector (presumably the best estimate even when
    /// `converged` is false — confirm with implementations)
    pub solution: Vec<T>,
    /// Number of iterations actually performed
    pub iterations: usize,
    /// Final residual norm (solver-defined norm)
    pub residual_norm: f64,
    /// Whether solver converged within the requested tolerance
    pub converged: bool,
    /// Error message if the solve failed outright
    pub error_message: Option<String>,
}
1082
/// Advanced sparse operations with optimization and acceleration
///
/// Wraps [`DefaultSparseOps`] and dispatches to parallel/SIMD paths when
/// the detected hardware capabilities and problem size justify it.
#[derive(Debug)]
pub struct AdvancedSparseOps {
    // Sequential fallback implementation.
    base_ops: DefaultSparseOps,
    // Capabilities detected once at construction time.
    acceleration_caps: SparseAccelerationCapabilities,
    // Cache for estimated operation timings, keyed by "op_rows_cols_nnz".
    performance_cache: HashMap<String, f64>, // Cache for operation timings
}
1090
1091impl AdvancedSparseOps {
1092    /// Create new advanced sparse operations instance
1093    pub fn new(device: Device) -> Self {
1094        let acceleration_caps = Self::detect_acceleration_capabilities(&device);
1095        let base_ops = DefaultSparseOps::new(device);
1096
1097        Self {
1098            base_ops,
1099            acceleration_caps,
1100            performance_cache: HashMap::new(),
1101        }
1102    }
1103
1104    /// Detect hardware acceleration capabilities
1105    fn detect_acceleration_capabilities(device: &Device) -> SparseAccelerationCapabilities {
1106        SparseAccelerationCapabilities {
1107            simd_width: if cfg!(target_arch = "x86_64") { 8 } else { 4 }, // AVX-256 vs NEON-128
1108            gpu_acceleration: device.device_type() != torsh_core::device::DeviceType::Cpu,
1109            specialized_hardware: false, // Would be detected based on device
1110            parallel_execution: true,
1111            memory_bandwidth: if device.device_type() == torsh_core::device::DeviceType::Cpu {
1112                50.0
1113            } else {
1114                500.0
1115            },
1116        }
1117    }
1118
1119    /// Optimized sparse matrix-vector multiplication with acceleration
1120    pub fn optimized_spmv(
1121        &mut self,
1122        matrix: &SparseMatrix<f32>,
1123        x: &[f32],
1124        y: &mut [f32],
1125    ) -> BackendResult<()> {
1126        let _cache_key = format!(
1127            "spmv_{}_{}_{}_{}",
1128            matrix.format as u8, matrix.rows, matrix.cols, matrix.nnz
1129        );
1130
1131        // Use parallel execution if available and matrix is large enough
1132        if self.acceleration_caps.parallel_execution && matrix.nnz > 10000 {
1133            self.parallel_spmv(matrix, x, y)
1134        } else if self.acceleration_caps.simd_width > 1 {
1135            self.simd_spmv(matrix, x, y)
1136        } else {
1137            self.base_ops.spmv(matrix, x, y)
1138        }
1139    }
1140
1141    /// Parallel sparse matrix-vector multiplication
1142    fn parallel_spmv(
1143        &self,
1144        matrix: &SparseMatrix<f32>,
1145        x: &[f32],
1146        y: &mut [f32],
1147    ) -> BackendResult<()> {
1148        match matrix.format {
1149            SparseFormat::Csr => {
1150                // ✅ SciRS2 POLICY: Use scirs2_core::parallel_ops instead of direct rayon
1151                use scirs2_core::parallel_ops::*;
1152
1153                // Parallel iteration over rows
1154                let row_chunks: Vec<_> = (0..matrix.rows).collect();
1155                let chunk_size = (matrix.rows + current_num_threads() - 1) / current_num_threads();
1156
1157                row_chunks.par_chunks(chunk_size).for_each(|chunk| {
1158                    for &row in chunk {
1159                        let start = matrix.row_indices[row];
1160                        let end = matrix.row_indices[row + 1];
1161                        let mut sum = 0.0;
1162
1163                        for idx in start..end {
1164                            let col = matrix.col_indices[idx];
1165                            let val = matrix.values[idx];
1166                            sum += val * x[col];
1167                        }
1168
1169                        // Safe because each thread works on disjoint rows
1170                        unsafe {
1171                            let y_ptr = y.as_ptr() as *mut f32;
1172                            *y_ptr.add(row) = sum;
1173                        }
1174                    }
1175                });
1176
1177                Ok(())
1178            }
1179            _ => self.base_ops.spmv(matrix, x, y), // Fall back to sequential for other formats
1180        }
1181    }
1182
1183    /// SIMD-accelerated sparse matrix-vector multiplication
1184    fn simd_spmv(&self, matrix: &SparseMatrix<f32>, x: &[f32], y: &mut [f32]) -> BackendResult<()> {
1185        // Placeholder for SIMD implementation
1186        // Real implementation would use architecture-specific SIMD intrinsics
1187        self.base_ops.spmv(matrix, x, y)
1188    }
1189
1190    /// Adaptive format selection based on operation and matrix characteristics
1191    pub fn adaptive_format_conversion(
1192        &self,
1193        matrix: &SparseMatrix<f32>,
1194        target_operation: SparseOperation,
1195    ) -> BackendResult<SparseMatrix<f32>> {
1196        let stats = matrix.statistics();
1197
1198        let optimal_format = if stats.is_well_balanced() {
1199            match target_operation {
1200                SparseOperation::SpMV => SparseFormat::Csr,
1201                SparseOperation::SpMM => SparseFormat::Csr,
1202                SparseOperation::Addition => SparseFormat::Coo,
1203                SparseOperation::Transpose => SparseFormat::Coo,
1204                SparseOperation::Iterative => SparseFormat::Csr,
1205            }
1206        } else {
1207            // For unbalanced matrices, prefer more flexible formats
1208            match target_operation {
1209                SparseOperation::SpMV if stats.max_row_nnz > stats.nnz / 10 => SparseFormat::Coo, // Very unbalanced
1210                _ => SparseFormatConverter::choose_optimal_format(matrix, target_operation),
1211            }
1212        };
1213
1214        if matrix.format == optimal_format {
1215            Ok(matrix.clone())
1216        } else {
1217            match optimal_format {
1218                SparseFormat::Csr => matrix.to_csr(),
1219                SparseFormat::Csc => matrix.to_csc(),
1220                SparseFormat::Bsr => {
1221                    let block_size = (8, 8); // Default block size
1222                    matrix.to_bsr(block_size)
1223                }
1224                _ => Ok(matrix.clone()),
1225            }
1226        }
1227    }
1228
1229    /// Benchmark and cache operation performance
1230    pub fn benchmark_operation(&mut self, operation: &str, matrix: &SparseMatrix<f32>) -> f64 {
1231        let cache_key = format!(
1232            "{}_{}_{}_{}",
1233            operation, matrix.rows, matrix.cols, matrix.nnz
1234        );
1235
1236        if let Some(&cached_time) = self.performance_cache.get(&cache_key) {
1237            return cached_time;
1238        }
1239
1240        // Simplified benchmarking - real implementation would use precise timing
1241        let estimated_time = match operation {
1242            "spmv" => {
1243                (matrix.nnz as f64 * 2.0) / (self.acceleration_caps.memory_bandwidth as f64 * 1e9)
1244            }
1245            "spmm" => {
1246                (matrix.nnz as f64 * matrix.cols as f64 * 2.0)
1247                    / (self.acceleration_caps.memory_bandwidth as f64 * 1e9)
1248            }
1249            _ => 0.001, // Default 1ms
1250        };
1251
1252        self.performance_cache.insert(cache_key, estimated_time);
1253        estimated_time
1254    }
1255}
1256
1257/// Sparse linear algebra utilities
1258pub struct SparseLinAlg;
1259
1260impl SparseLinAlg {
1261    /// Compute sparse matrix norm (Frobenius norm)
1262    pub fn frobenius_norm<T>(matrix: &SparseMatrix<T>) -> f64
1263    where
1264        T: Clone + Default + PartialEq,
1265        f64: From<T>,
1266    {
1267        let mut sum = 0.0;
1268        for value in &matrix.values {
1269            let val: f64 = value.clone().into();
1270            sum += val * val;
1271        }
1272        sum.sqrt()
1273    }
1274
1275    /// Check if sparse matrix is symmetric
1276    pub fn is_symmetric(matrix: &SparseMatrix<f32>, tolerance: f32) -> BackendResult<bool> {
1277        if matrix.rows != matrix.cols {
1278            return Ok(false);
1279        }
1280
1281        // Convert to COO for easier comparison
1282        let coo = if matrix.format == SparseFormat::Coo {
1283            matrix.clone()
1284        } else {
1285            return Err(TorshError::ComputeError(
1286                "Symmetry check requires COO format".to_string(),
1287            ));
1288        };
1289
1290        // Create a map of (row, col) -> value for quick lookup
1291        let mut entries: HashMap<(usize, usize), f32> = HashMap::new();
1292        for i in 0..coo.nnz {
1293            entries.insert((coo.row_indices[i], coo.col_indices[i]), coo.values[i]);
1294        }
1295
1296        // Check symmetry
1297        for ((row, col), &value) in &entries {
1298            if let Some(&transpose_value) = entries.get(&(*col, *row)) {
1299                if (value - transpose_value).abs() > tolerance {
1300                    return Ok(false);
1301                }
1302            } else if value.abs() > tolerance {
1303                // Non-zero element with no transpose counterpart
1304                return Ok(false);
1305            }
1306        }
1307
1308        Ok(true)
1309    }
1310
1311    /// Extract diagonal of sparse matrix
1312    pub fn diagonal<T: Clone + Default>(matrix: &SparseMatrix<T>) -> Vec<T> {
1313        let mut diag = vec![T::default(); matrix.rows.min(matrix.cols)];
1314
1315        match matrix.format {
1316            SparseFormat::Coo => {
1317                for i in 0..matrix.nnz {
1318                    let row = matrix.row_indices[i];
1319                    let col = matrix.col_indices[i];
1320                    if row == col && row < diag.len() {
1321                        diag[row] = matrix.values[i].clone();
1322                    }
1323                }
1324            }
1325            SparseFormat::Csr => {
1326                for row in 0..matrix.rows.min(diag.len()) {
1327                    let start = matrix.row_indices[row];
1328                    let end = matrix.row_indices[row + 1];
1329
1330                    for idx in start..end {
1331                        let col = matrix.col_indices[idx];
1332                        if col == row {
1333                            diag[row] = matrix.values[idx].clone();
1334                            break;
1335                        } else if col > row {
1336                            break; // Assuming sorted columns
1337                        }
1338                    }
1339                }
1340            }
1341            _ => {
1342                // For other formats, convert to COO first (simplified)
1343                // Real implementation would handle each format optimally
1344            }
1345        }
1346
1347        diag
1348    }
1349}
1350
#[cfg(test)]
mod tests {
    use super::*;

    // Basic COO construction: dimensions, nnz bookkeeping, sparsity ratio.
    #[test]
    fn test_sparse_matrix_creation() {
        let mut matrix = SparseMatrix::<f32>::new_coo(3, 3);
        assert_eq!(matrix.rows, 3);
        assert_eq!(matrix.cols, 3);
        assert_eq!(matrix.nnz, 0);
        assert_eq!(matrix.format, SparseFormat::Coo);

        // Insert some values
        matrix.insert_coo(0, 0, 1.0).unwrap();
        matrix.insert_coo(1, 1, 2.0).unwrap();
        matrix.insert_coo(2, 2, 3.0).unwrap();

        assert_eq!(matrix.nnz, 3);
        assert_eq!(matrix.sparsity_ratio(), 3.0 / 9.0);
        assert!(matrix.is_sparse());
    }

    // COO -> CSR conversion: checks the compressed row pointer layout.
    #[test]
    fn test_coo_to_csr_conversion() {
        let mut coo = SparseMatrix::<f32>::new_coo(3, 3);
        coo.insert_coo(0, 0, 1.0).unwrap();
        coo.insert_coo(0, 2, 2.0).unwrap();
        coo.insert_coo(1, 1, 3.0).unwrap();
        coo.insert_coo(2, 0, 4.0).unwrap();
        coo.insert_coo(2, 2, 5.0).unwrap();

        let csr = coo.to_csr().unwrap();
        assert_eq!(csr.format, SparseFormat::Csr);
        assert_eq!(csr.nnz, 5);

        // Check row_ptr array: should be [0, 2, 3, 5]
        // (row 0 has 2 entries, row 1 has 1, row 2 has 2).
        assert_eq!(csr.row_indices, vec![0, 2, 3, 5]);
    }

    // SpMV on a diagonal matrix: y[i] = diag[i] * x[i].
    #[test]
    fn test_sparse_spmv() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        // Create a simple 3x3 matrix
        let mut matrix = SparseMatrix::<f32>::new_coo(3, 3);
        matrix.insert_coo(0, 0, 2.0).unwrap();
        matrix.insert_coo(1, 1, 3.0).unwrap();
        matrix.insert_coo(2, 2, 4.0).unwrap();

        // Convert to CSR for SpMV
        let csr_matrix = matrix.to_csr().unwrap();

        let x = vec![1.0, 2.0, 3.0];
        let mut y = vec![0.0; 3];

        sparse_ops.spmv(&csr_matrix, &x, &mut y).unwrap();

        assert_eq!(y, vec![2.0, 6.0, 12.0]);
    }

    // Densification: row-major output with zeros at unset positions.
    #[test]
    fn test_sparse_to_dense() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let mut matrix = SparseMatrix::<f32>::new_coo(2, 2);
        matrix.insert_coo(0, 0, 1.0).unwrap();
        matrix.insert_coo(1, 1, 2.0).unwrap();

        let dense = sparse_ops.to_dense(&matrix).unwrap();
        assert_eq!(dense, vec![1.0, 0.0, 0.0, 2.0]);
    }

    // Sparsification: values below the 0.1 threshold are dropped.
    #[test]
    fn test_sparse_from_dense() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let dense = vec![1.0, 0.0, 0.0, 2.0];
        let sparse = sparse_ops.from_dense(&dense, 2, 2, 0.1).unwrap();

        assert_eq!(sparse.nnz, 2);
        assert_eq!(sparse.sparsity_ratio(), 0.5);
    }

    // Transpose swaps dimensions and the row/col index arrays.
    #[test]
    fn test_sparse_transpose() {
        let device = Device::cpu().unwrap();
        let sparse_ops = DefaultSparseOps::new(device);

        let mut matrix = SparseMatrix::<f32>::new_coo(2, 3);
        matrix.insert_coo(0, 1, 1.0).unwrap();
        matrix.insert_coo(1, 2, 2.0).unwrap();

        let transposed = sparse_ops.transpose(&matrix).unwrap();

        assert_eq!(transposed.rows, 3);
        assert_eq!(transposed.cols, 2);
        assert_eq!(transposed.nnz, 2);

        // Check that indices are swapped
        assert_eq!(transposed.row_indices, vec![1, 2]); // original col_indices
        assert_eq!(transposed.col_indices, vec![0, 1]); // original row_indices
    }

    // Relative ordering of format memory estimates at 1% density.
    #[test]
    fn test_memory_usage_estimation() {
        let rows = 1000;
        let cols = 1000;
        let nnz = 10000; // 1% sparsity

        let coo_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Coo);
        let csr_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Csr);
        let dense_memory =
            SparseFormatConverter::estimate_memory_usage(rows, cols, nnz, SparseFormat::Dense);

        // Dense should use much more memory
        assert!(dense_memory > coo_memory);
        assert!(dense_memory > csr_memory);

        // CSR should be slightly more memory efficient than COO for this case
        // (one row_ptr entry per row is cheaper than one row index per nnz).
        assert!(csr_memory < coo_memory);
    }

    // Operation-driven format heuristic: CSR for SpMV, COO for addition.
    #[test]
    fn test_format_selection() {
        let matrix = SparseMatrix::<f32>::new_coo(100, 100);

        let spmv_format =
            SparseFormatConverter::choose_optimal_format(&matrix, SparseOperation::SpMV);
        assert_eq!(spmv_format, SparseFormat::Csr);

        let add_format =
            SparseFormatConverter::choose_optimal_format(&matrix, SparseOperation::Addition);
        assert_eq!(add_format, SparseFormat::Coo);
    }
}