Skip to main content

tensorlogic_sklears_kernels/
sparse.rs

//! Sparse kernel matrix support for large-scale problems.
//!
//! Provides efficient storage and operations for sparse kernel matrices using
//! the Compressed Sparse Row (CSR) format for memory-efficient representation.
//!
//! # Features
//!
//! - **Efficient Storage**: CSR format for sparse matrices with configurable thresholds
//! - **Matrix Operations**: SpMV, transpose, addition, scaling, Frobenius norm
//! - **Parallel Construction**: Multi-threaded matrix building with rayon
//! - **Iterator Support**: Efficient iteration over non-zero entries
//! - **Flexible Builders**: Configurable threshold and max entries per row
//!
//! # Example
//!
//! ```rust
//! use tensorlogic_sklears_kernels::{SparseKernelMatrix, SparseKernelMatrixBuilder};
//! use tensorlogic_sklears_kernels::tensor_kernels::LinearKernel;
//!
//! // Build a sparse kernel matrix with parallel computation
//! let builder = SparseKernelMatrixBuilder::new()
//!     .with_threshold(0.1).unwrap()
//!     .with_max_entries_per_row(100).unwrap();
//!
//! let kernel = LinearKernel::new();
//! let data = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
//! let kernel_matrix = builder.build_parallel(&data, &kernel).unwrap();
//!
//! // Sparse matrix-vector multiplication on a hand-built matrix
//! let mut matrix = SparseKernelMatrix::new(3);
//! matrix.set(0, 0, 2.0);
//! matrix.set(1, 1, 3.0);
//! let x = vec![1.0, 2.0, 0.0];
//! let y = matrix.spmv(&x).unwrap();
//!
//! // Iterate over non-zero entries
//! for (row, col, value) in matrix.iter_nonzeros() {
//!     println!("({}, {}) = {}", row, col, value);
//! }
//! ```
41
42use std::collections::HashMap;
43
44use serde::{Deserialize, Serialize};
45
46use crate::error::{KernelError, Result};
47use crate::types::Kernel;
48
49/// Sparse kernel matrix using Compressed Sparse Row (CSR) format
50///
51/// Stores only non-zero entries for efficient memory usage.
52///
53/// # Example
54///
55/// ```rust
56/// use tensorlogic_sklears_kernels::SparseKernelMatrix;
57///
58/// let mut matrix = SparseKernelMatrix::new(3);
59/// matrix.set(0, 1, 0.8);
60/// matrix.set(1, 2, 0.6);
61///
62/// assert_eq!(matrix.get(0, 1), Some(0.8));
63/// assert_eq!(matrix.get(0, 2), None);
64/// assert_eq!(matrix.nnz(), 2);
65/// ```
66#[derive(Clone, Debug, Serialize, Deserialize)]
67pub struct SparseKernelMatrix {
68    /// Number of rows/columns (square matrix)
69    size: usize,
70    /// Row pointers for CSR format
71    row_ptr: Vec<usize>,
72    /// Column indices
73    col_idx: Vec<usize>,
74    /// Non-zero values
75    values: Vec<f64>,
76    /// Temporary map for construction (not serialized)
77    #[serde(skip)]
78    temp_map: HashMap<(usize, usize), f64>,
79}
80
81impl SparseKernelMatrix {
82    /// Create a new sparse kernel matrix
83    pub fn new(size: usize) -> Self {
84        Self {
85            size,
86            row_ptr: vec![0; size + 1],
87            col_idx: Vec::new(),
88            values: Vec::new(),
89            temp_map: HashMap::new(),
90        }
91    }
92
93    /// Set a value in the matrix
94    pub fn set(&mut self, row: usize, col: usize, value: f64) {
95        if row >= self.size || col >= self.size {
96            return;
97        }
98
99        if value.abs() < 1e-10 {
100            // Remove near-zero values
101            self.temp_map.remove(&(row, col));
102        } else {
103            self.temp_map.insert((row, col), value);
104        }
105    }
106
107    /// Get a value from the matrix
108    pub fn get(&self, row: usize, col: usize) -> Option<f64> {
109        if row >= self.size || col >= self.size {
110            return None;
111        }
112
113        // Check temp map first
114        if let Some(&value) = self.temp_map.get(&(row, col)) {
115            return Some(value);
116        }
117
118        // Search in CSR format
119        let start = self.row_ptr[row];
120        let end = self.row_ptr[row + 1];
121
122        for i in start..end {
123            if self.col_idx[i] == col {
124                return Some(self.values[i]);
125            }
126        }
127
128        None
129    }
130
131    /// Finalize the matrix (convert temp map to CSR format)
132    pub fn finalize(&mut self) {
133        if self.temp_map.is_empty() {
134            return;
135        }
136
137        // Clear existing CSR data
138        self.col_idx.clear();
139        self.values.clear();
140        self.row_ptr = vec![0; self.size + 1];
141
142        // Sort entries by row, then column
143        let mut entries: Vec<_> = self.temp_map.iter().collect();
144        entries.sort_by_key(|&((row, col), _)| (*row, *col));
145
146        // Build CSR format
147        let mut current_row = 0;
148        for (&(row, col), &value) in &entries {
149            // Update row pointers
150            while current_row < row {
151                current_row += 1;
152                self.row_ptr[current_row] = self.col_idx.len();
153            }
154
155            self.col_idx.push(col);
156            self.values.push(value);
157        }
158
159        // Finalize row pointers
160        while current_row < self.size {
161            current_row += 1;
162            self.row_ptr[current_row] = self.col_idx.len();
163        }
164
165        // Clear temp map
166        self.temp_map.clear();
167    }
168
169    /// Get number of non-zero entries
170    pub fn nnz(&self) -> usize {
171        self.values.len() + self.temp_map.len()
172    }
173
174    /// Get matrix size
175    pub fn size(&self) -> usize {
176        self.size
177    }
178
179    /// Get density (fraction of non-zero entries)
180    pub fn density(&self) -> f64 {
181        let total = self.size * self.size;
182        if total == 0 {
183            0.0
184        } else {
185            self.nnz() as f64 / total as f64
186        }
187    }
188
189    /// Convert to dense matrix
190    #[allow(clippy::needless_range_loop)]
191    pub fn to_dense(&mut self) -> Vec<Vec<f64>> {
192        self.finalize();
193
194        let mut dense = vec![vec![0.0; self.size]; self.size];
195
196        for row in 0..self.size {
197            let start = self.row_ptr[row];
198            let end = self.row_ptr[row + 1];
199
200            for i in start..end {
201                let col = self.col_idx[i];
202                let value = self.values[i];
203                dense[row][col] = value;
204            }
205        }
206
207        dense
208    }
209
210    /// Compute sparse kernel matrix from data with threshold
211    pub fn from_kernel_with_threshold(
212        data: &[Vec<f64>],
213        kernel: &dyn Kernel,
214        threshold: f64,
215    ) -> Result<Self> {
216        let n = data.len();
217        let mut matrix = Self::new(n);
218
219        for i in 0..n {
220            for j in 0..n {
221                let value = kernel.compute(&data[i], &data[j])?;
222                if value.abs() >= threshold {
223                    matrix.set(i, j, value);
224                }
225            }
226        }
227
228        matrix.finalize();
229        Ok(matrix)
230    }
231
232    /// Get row as sparse vector
233    pub fn row(&mut self, row_idx: usize) -> Option<Vec<(usize, f64)>> {
234        if row_idx >= self.size {
235            return None;
236        }
237
238        self.finalize();
239
240        let start = self.row_ptr[row_idx];
241        let end = self.row_ptr[row_idx + 1];
242
243        let mut row_data = Vec::new();
244        for i in start..end {
245            row_data.push((self.col_idx[i], self.values[i]));
246        }
247
248        Some(row_data)
249    }
250}
251
/// Sparse kernel matrix builder with configuration.
///
/// Configure it with `with_threshold` / `with_max_entries_per_row`, then call `build`
/// or `build_parallel`.
#[derive(Clone, Debug)]
pub struct SparseKernelMatrixBuilder {
    /// Sparsity threshold: kernel values with magnitude below this are not stored.
    threshold: f64,
    /// Optional cap on stored entries per row; only the largest-magnitude ones are kept.
    max_entries_per_row: Option<usize>,
}
259
260impl SparseKernelMatrixBuilder {
261    /// Create a new builder
262    pub fn new() -> Self {
263        Self {
264            threshold: 1e-10,
265            max_entries_per_row: None,
266        }
267    }
268
269    /// Set sparsity threshold
270    pub fn with_threshold(mut self, threshold: f64) -> Result<Self> {
271        if threshold < 0.0 {
272            return Err(KernelError::InvalidParameter {
273                parameter: "threshold".to_string(),
274                value: threshold.to_string(),
275                reason: "must be non-negative".to_string(),
276            });
277        }
278        self.threshold = threshold;
279        Ok(self)
280    }
281
282    /// Set maximum entries per row
283    pub fn with_max_entries_per_row(mut self, max_entries: usize) -> Result<Self> {
284        if max_entries == 0 {
285            return Err(KernelError::InvalidParameter {
286                parameter: "max_entries_per_row".to_string(),
287                value: max_entries.to_string(),
288                reason: "must be positive".to_string(),
289            });
290        }
291        self.max_entries_per_row = Some(max_entries);
292        Ok(self)
293    }
294
295    /// Build sparse kernel matrix from data
296    pub fn build(&self, data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<SparseKernelMatrix> {
297        let n = data.len();
298        let mut matrix = SparseKernelMatrix::new(n);
299
300        for i in 0..n {
301            let mut row_entries = Vec::new();
302
303            // Compute all values for this row
304            for j in 0..n {
305                let value = kernel.compute(&data[i], &data[j])?;
306                if value.abs() >= self.threshold {
307                    row_entries.push((j, value));
308                }
309            }
310
311            // If max_entries_per_row is set, keep only top-k entries
312            if let Some(max_entries) = self.max_entries_per_row {
313                if row_entries.len() > max_entries {
314                    // Sort by absolute value (descending)
315                    row_entries.sort_by(|(_, a), (_, b)| b.abs().partial_cmp(&a.abs()).unwrap());
316                    row_entries.truncate(max_entries);
317                }
318            }
319
320            // Add entries to matrix
321            for (j, value) in row_entries {
322                matrix.set(i, j, value);
323            }
324        }
325
326        matrix.finalize();
327        Ok(matrix)
328    }
329}
330
331impl Default for SparseKernelMatrixBuilder {
332    fn default() -> Self {
333        Self::new()
334    }
335}
336
337/// Advanced sparse matrix operations
338impl SparseKernelMatrix {
339    /// Sparse matrix-vector multiplication: y = A * x
340    pub fn spmv(&mut self, x: &[f64]) -> Result<Vec<f64>> {
341        if x.len() != self.size {
342            return Err(KernelError::InvalidParameter {
343                parameter: "x".to_string(),
344                value: x.len().to_string(),
345                reason: format!("vector length must match matrix size {}", self.size),
346            });
347        }
348
349        self.finalize();
350
351        let mut y = vec![0.0; self.size];
352
353        for (row, y_elem) in y.iter_mut().enumerate() {
354            let start = self.row_ptr[row];
355            let end = self.row_ptr[row + 1];
356
357            let mut sum = 0.0;
358            for i in start..end {
359                let col = self.col_idx[i];
360                let value = self.values[i];
361                sum += value * x[col];
362            }
363            *y_elem = sum;
364        }
365
366        Ok(y)
367    }
368
369    /// Sparse matrix transpose
370    pub fn transpose(&self) -> Result<Self> {
371        let mut transposed = Self::new(self.size);
372
373        for row in 0..self.size {
374            let start = self.row_ptr[row];
375            let end = self.row_ptr[row + 1];
376
377            for i in start..end {
378                let col = self.col_idx[i];
379                let value = self.values[i];
380                transposed.set(col, row, value);
381            }
382        }
383
384        transposed.finalize();
385        Ok(transposed)
386    }
387
388    /// Add two sparse matrices element-wise
389    pub fn add(&mut self, other: &Self) -> Result<Self> {
390        if self.size != other.size {
391            return Err(KernelError::InvalidParameter {
392                parameter: "other".to_string(),
393                value: other.size.to_string(),
394                reason: format!("matrix sizes must match: {} vs {}", self.size, other.size),
395            });
396        }
397
398        self.finalize();
399
400        // Clone and finalize other to ensure all values are in CSR format
401        let mut other_finalized = other.clone();
402        other_finalized.finalize();
403
404        let mut result = Self::new(self.size);
405
406        // Add values from self
407        for row in 0..self.size {
408            let start = self.row_ptr[row];
409            let end = self.row_ptr[row + 1];
410
411            for i in start..end {
412                let col = self.col_idx[i];
413                let value = self.values[i];
414                result.set(row, col, value);
415            }
416        }
417
418        // Add values from other
419        for row in 0..other_finalized.size {
420            let start = other_finalized.row_ptr[row];
421            let end = other_finalized.row_ptr[row + 1];
422
423            for i in start..end {
424                let col = other_finalized.col_idx[i];
425                let value = other_finalized.values[i];
426                let existing = result.get(row, col).unwrap_or(0.0);
427                result.set(row, col, existing + value);
428            }
429        }
430
431        result.finalize();
432        Ok(result)
433    }
434
435    /// Frobenius norm of the sparse matrix
436    pub fn frobenius_norm(&self) -> f64 {
437        let mut sum_squares = 0.0;
438
439        for row in 0..self.size {
440            let start = self.row_ptr[row];
441            let end = self.row_ptr[row + 1];
442
443            for i in start..end {
444                let value = self.values[i];
445                sum_squares += value * value;
446            }
447        }
448
449        sum_squares.sqrt()
450    }
451
452    /// Iterator over non-zero entries (row, col, value)
453    pub fn iter_nonzeros(&mut self) -> SparseMatrixIterator<'_> {
454        self.finalize();
455        SparseMatrixIterator {
456            matrix: self,
457            current_row: 0,
458            current_idx: 0,
459        }
460    }
461
462    /// Scale the matrix by a scalar
463    pub fn scale(&mut self, scalar: f64) {
464        for value in &mut self.values {
465            *value *= scalar;
466        }
467
468        for value in self.temp_map.values_mut() {
469            *value *= scalar;
470        }
471    }
472}
473
474/// Iterator for sparse matrix non-zero entries
475pub struct SparseMatrixIterator<'a> {
476    matrix: &'a SparseKernelMatrix,
477    current_row: usize,
478    current_idx: usize,
479}
480
481impl<'a> Iterator for SparseMatrixIterator<'a> {
482    type Item = (usize, usize, f64);
483
484    fn next(&mut self) -> Option<Self::Item> {
485        while self.current_row < self.matrix.size {
486            let row_end = self.matrix.row_ptr[self.current_row + 1];
487
488            if self.current_idx < row_end {
489                let col = self.matrix.col_idx[self.current_idx];
490                let value = self.matrix.values[self.current_idx];
491                self.current_idx += 1;
492                return Some((self.current_row, col, value));
493            }
494
495            self.current_row += 1;
496            self.current_idx = self
497                .matrix
498                .row_ptr
499                .get(self.current_row)
500                .copied()
501                .unwrap_or(0);
502        }
503
504        None
505    }
506}
507
508/// Parallel sparse kernel matrix builder
509impl SparseKernelMatrixBuilder {
510    /// Build sparse kernel matrix with parallel computation
511    pub fn build_parallel(
512        &self,
513        data: &[Vec<f64>],
514        kernel: &dyn Kernel,
515    ) -> Result<SparseKernelMatrix> {
516        use rayon::prelude::*;
517
518        let n = data.len();
519        let mut matrix = SparseKernelMatrix::new(n);
520
521        // Compute rows in parallel
522        let row_data: Vec<Vec<(usize, f64)>> = (0..n)
523            .into_par_iter()
524            .map(|i| {
525                let mut row_entries = Vec::new();
526
527                for j in 0..n {
528                    match kernel.compute(&data[i], &data[j]) {
529                        Ok(value) => {
530                            if value.abs() >= self.threshold {
531                                row_entries.push((j, value));
532                            }
533                        }
534                        Err(_) => continue,
535                    }
536                }
537
538                // If max_entries_per_row is set, keep only top-k entries
539                if let Some(max_entries) = self.max_entries_per_row {
540                    if row_entries.len() > max_entries {
541                        row_entries
542                            .sort_by(|(_, a), (_, b)| b.abs().partial_cmp(&a.abs()).unwrap());
543                        row_entries.truncate(max_entries);
544                    }
545                }
546
547                row_entries
548            })
549            .collect();
550
551        // Sequentially insert into matrix
552        for (i, row_entries) in row_data.into_iter().enumerate() {
553            for (j, value) in row_entries {
554                matrix.set(i, j, value);
555            }
556        }
557
558        matrix.finalize();
559        Ok(matrix)
560    }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    #[test]
    fn test_sparse_matrix_creation() {
        let matrix = SparseKernelMatrix::new(3);
        assert_eq!(matrix.size(), 3);
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_sparse_matrix_set_get() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);

        assert_eq!(matrix.get(0, 1), Some(0.8));
        assert_eq!(matrix.get(1, 2), Some(0.6));
        assert_eq!(matrix.get(0, 2), None);
    }

    #[test]
    fn test_sparse_matrix_out_of_bounds_access() {
        let mut matrix = SparseKernelMatrix::new(2);
        // Out-of-bounds writes are ignored, reads return None.
        matrix.set(2, 0, 1.0);
        matrix.set(0, 2, 1.0);

        assert_eq!(matrix.nnz(), 0);
        assert_eq!(matrix.get(2, 0), None);
        assert_eq!(matrix.get(0, 2), None);
        assert!(matrix.row(2).is_none());
    }

    #[test]
    fn test_sparse_matrix_empty() {
        let mut matrix = SparseKernelMatrix::new(0);

        assert_eq!(matrix.density(), 0.0);
        assert!(matrix.to_dense().is_empty());
        assert_eq!(matrix.spmv(&[]).unwrap(), Vec::<f64>::new());
        assert_eq!(matrix.iter_nonzeros().count(), 0);
        assert_eq!(matrix.frobenius_norm(), 0.0);
    }

    #[test]
    fn test_sparse_matrix_finalize() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);
        matrix.set(2, 0, 0.4);

        matrix.finalize();

        assert_eq!(matrix.get(0, 1), Some(0.8));
        assert_eq!(matrix.get(1, 2), Some(0.6));
        assert_eq!(matrix.get(2, 0), Some(0.4));
    }

    #[test]
    fn test_sparse_matrix_nnz() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);

        assert_eq!(matrix.nnz(), 2);
    }

    #[test]
    fn test_sparse_matrix_density() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);

        let density = matrix.density();
        assert!((density - 2.0 / 9.0).abs() < 1e-10);
    }

    #[test]
    fn test_sparse_matrix_to_dense() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);

        let dense = matrix.to_dense();
        assert_eq!(dense.len(), 3);
        assert!((dense[0][1] - 0.8).abs() < 1e-10);
        assert!((dense[1][2] - 0.6).abs() < 1e-10);
        assert!(dense[0][0].abs() < 1e-10);
    }

    #[test]
    fn test_sparse_matrix_from_kernel() {
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let mut matrix =
            SparseKernelMatrix::from_kernel_with_threshold(&data, &kernel, 0.1).unwrap();

        assert!(matrix.nnz() > 0);
        let dense = matrix.to_dense();
        assert_eq!(dense.len(), 3);
    }

    #[test]
    fn test_sparse_matrix_row() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(0, 2, 0.6);

        let row = matrix.row(0).unwrap();
        assert_eq!(row.len(), 2);
        assert!(row.contains(&(1, 0.8)));
        assert!(row.contains(&(2, 0.6)));
    }

    #[test]
    fn test_sparse_matrix_builder() {
        let builder = SparseKernelMatrixBuilder::new();
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let matrix = builder.build(&data, &kernel).unwrap();
        assert!(matrix.nnz() > 0);
    }

    #[test]
    fn test_sparse_matrix_builder_with_threshold() {
        let builder = SparseKernelMatrixBuilder::new()
            .with_threshold(0.5)
            .unwrap();
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0]];

        let matrix = builder.build(&data, &kernel).unwrap();
        assert!(matrix.nnz() > 0);
    }

    #[test]
    fn test_sparse_matrix_builder_invalid_threshold() {
        let result = SparseKernelMatrixBuilder::new().with_threshold(-0.1);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_matrix_builder_max_entries() {
        let builder = SparseKernelMatrixBuilder::new()
            .with_max_entries_per_row(2)
            .unwrap();
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let mut matrix = builder.build(&data, &kernel).unwrap();
        // Each row should have at most 2 entries
        for i in 0..matrix.size() {
            let row = matrix.row(i).unwrap();
            assert!(row.len() <= 2);
        }
    }

    #[test]
    fn test_sparse_matrix_builder_max_entries_keeps_largest() {
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let mut full = SparseKernelMatrixBuilder::new()
            .build(&data, &kernel)
            .unwrap();
        let mut capped = SparseKernelMatrixBuilder::new()
            .with_max_entries_per_row(1)
            .unwrap()
            .build(&data, &kernel)
            .unwrap();

        // The capped row must hold exactly the largest-magnitude entry of the full row.
        for i in 0..full.size() {
            let full_row = full.row(i).unwrap();
            let capped_row = capped.row(i).unwrap();
            assert_eq!(capped_row.len(), full_row.len().min(1));

            let largest = full_row
                .iter()
                .map(|&(_, value)| value.abs())
                .fold(0.0_f64, f64::max);
            if let Some(&(_, value)) = capped_row.first() {
                assert_eq!(value.abs(), largest);
            }
        }
    }

    #[test]
    fn test_sparse_matrix_builder_invalid_max_entries() {
        let result = SparseKernelMatrixBuilder::new().with_max_entries_per_row(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_matrix_zero_threshold() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 1e-11); // Very small value (below 1e-10 threshold)
        matrix.finalize();

        // Should be treated as zero and filtered out
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_sparse_matrix_spmv() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 0, 2.0);
        matrix.set(0, 2, 1.0);
        matrix.set(1, 1, 3.0);
        matrix.set(2, 0, 1.0);
        matrix.set(2, 2, 2.0);

        let x = vec![1.0, 2.0, 3.0];
        let y = matrix.spmv(&x).unwrap();

        assert_eq!(y.len(), 3);
        assert!((y[0] - 5.0).abs() < 1e-10); // 2*1 + 1*3
        assert!((y[1] - 6.0).abs() < 1e-10); // 3*2
        assert!((y[2] - 7.0).abs() < 1e-10); // 1*1 + 2*3
    }

    #[test]
    fn test_sparse_matrix_spmv_invalid_size() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 0, 1.0);

        let x = vec![1.0, 2.0]; // Wrong size
        let result = matrix.spmv(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_matrix_transpose() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);
        matrix.set(2, 0, 0.4);
        matrix.finalize();

        let transposed = matrix.transpose().unwrap();

        assert_eq!(transposed.get(1, 0), Some(0.8));
        assert_eq!(transposed.get(2, 1), Some(0.6));
        assert_eq!(transposed.get(0, 2), Some(0.4));
    }

    #[test]
    fn test_sparse_matrix_add() {
        let mut matrix1 = SparseKernelMatrix::new(3);
        matrix1.set(0, 0, 1.0);
        matrix1.set(0, 1, 2.0);
        matrix1.set(1, 1, 3.0);

        let mut matrix2 = SparseKernelMatrix::new(3);
        matrix2.set(0, 1, 1.0);
        matrix2.set(1, 2, 4.0);
        matrix2.set(2, 2, 5.0);

        let result = matrix1.add(&matrix2).unwrap();

        assert_eq!(result.get(0, 0), Some(1.0));
        assert_eq!(result.get(0, 1), Some(3.0)); // 2.0 + 1.0
        assert_eq!(result.get(1, 1), Some(3.0));
        assert_eq!(result.get(1, 2), Some(4.0));
        assert_eq!(result.get(2, 2), Some(5.0));
    }

    #[test]
    fn test_sparse_matrix_add_cancellation() {
        let mut matrix1 = SparseKernelMatrix::new(2);
        matrix1.set(0, 0, 1.5);

        let mut matrix2 = SparseKernelMatrix::new(2);
        matrix2.set(0, 0, -1.5);
        matrix2.set(1, 1, 2.0);

        let result = matrix1.add(&matrix2).unwrap();

        // Entries that cancel to zero are not stored.
        assert_eq!(result.get(0, 0), None);
        assert_eq!(result.get(1, 1), Some(2.0));
        assert_eq!(result.nnz(), 1);
    }

    #[test]
    fn test_sparse_matrix_add_invalid_size() {
        let mut matrix1 = SparseKernelMatrix::new(3);
        matrix1.set(0, 0, 1.0);

        let matrix2 = SparseKernelMatrix::new(2);
        let result = matrix1.add(&matrix2);
        assert!(result.is_err());
    }

    #[test]
    fn test_sparse_matrix_frobenius_norm() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 0, 3.0);
        matrix.set(1, 1, 4.0);
        matrix.finalize();

        let norm = matrix.frobenius_norm();
        assert!((norm - 5.0).abs() < 1e-10); // sqrt(3^2 + 4^2) = 5
    }

    #[test]
    fn test_sparse_matrix_iterator() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 1, 0.8);
        matrix.set(1, 2, 0.6);
        matrix.set(2, 0, 0.4);

        let entries: Vec<_> = matrix.iter_nonzeros().collect();

        assert_eq!(entries.len(), 3);
        assert!(entries.contains(&(0, 1, 0.8)));
        assert!(entries.contains(&(1, 2, 0.6)));
        assert!(entries.contains(&(2, 0, 0.4)));
    }

    #[test]
    fn test_sparse_matrix_scale() {
        let mut matrix = SparseKernelMatrix::new(3);
        matrix.set(0, 0, 2.0);
        matrix.set(1, 1, 4.0);
        matrix.finalize();

        matrix.scale(0.5);

        assert_eq!(matrix.get(0, 0), Some(1.0));
        assert_eq!(matrix.get(1, 1), Some(2.0));
    }

    #[test]
    fn test_sparse_matrix_scale_pending_entries() {
        let mut matrix = SparseKernelMatrix::new(2);
        matrix.set(0, 1, 3.0);

        // Scaling before finalize must also affect staged entries.
        matrix.scale(2.0);

        assert_eq!(matrix.get(0, 1), Some(6.0));
    }

    #[test]
    fn test_sparse_matrix_builder_parallel() {
        let builder = SparseKernelMatrixBuilder::new();
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let mut matrix = builder.build_parallel(&data, &kernel).unwrap();
        assert!(matrix.nnz() > 0);

        // Compare with sequential build: same entries, same values.
        let mut matrix_seq = builder.build(&data, &kernel).unwrap();
        assert_eq!(matrix.nnz(), matrix_seq.nnz());
        assert_eq!(matrix.to_dense(), matrix_seq.to_dense());
    }

    #[test]
    fn test_sparse_matrix_parallel_with_threshold() {
        let builder = SparseKernelMatrixBuilder::new()
            .with_threshold(0.5)
            .unwrap();
        let kernel = LinearKernel::new();
        let data = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];

        let matrix = builder.build_parallel(&data, &kernel).unwrap();
        assert!(matrix.nnz() > 0);
    }
}