1use std::collections::HashMap;
43
44use serde::{Deserialize, Serialize};
45
46use crate::error::{KernelError, Result};
47use crate::types::Kernel;
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
67pub struct SparseKernelMatrix {
68 size: usize,
70 row_ptr: Vec<usize>,
72 col_idx: Vec<usize>,
74 values: Vec<f64>,
76 #[serde(skip)]
78 temp_map: HashMap<(usize, usize), f64>,
79}
80
81impl SparseKernelMatrix {
82 pub fn new(size: usize) -> Self {
84 Self {
85 size,
86 row_ptr: vec![0; size + 1],
87 col_idx: Vec::new(),
88 values: Vec::new(),
89 temp_map: HashMap::new(),
90 }
91 }
92
93 pub fn set(&mut self, row: usize, col: usize, value: f64) {
95 if row >= self.size || col >= self.size {
96 return;
97 }
98
99 if value.abs() < 1e-10 {
100 self.temp_map.remove(&(row, col));
102 } else {
103 self.temp_map.insert((row, col), value);
104 }
105 }
106
107 pub fn get(&self, row: usize, col: usize) -> Option<f64> {
109 if row >= self.size || col >= self.size {
110 return None;
111 }
112
113 if let Some(&value) = self.temp_map.get(&(row, col)) {
115 return Some(value);
116 }
117
118 let start = self.row_ptr[row];
120 let end = self.row_ptr[row + 1];
121
122 for i in start..end {
123 if self.col_idx[i] == col {
124 return Some(self.values[i]);
125 }
126 }
127
128 None
129 }
130
131 pub fn finalize(&mut self) {
133 if self.temp_map.is_empty() {
134 return;
135 }
136
137 self.col_idx.clear();
139 self.values.clear();
140 self.row_ptr = vec![0; self.size + 1];
141
142 let mut entries: Vec<_> = self.temp_map.iter().collect();
144 entries.sort_by_key(|&((row, col), _)| (*row, *col));
145
146 let mut current_row = 0;
148 for (&(row, col), &value) in &entries {
149 while current_row < row {
151 current_row += 1;
152 self.row_ptr[current_row] = self.col_idx.len();
153 }
154
155 self.col_idx.push(col);
156 self.values.push(value);
157 }
158
159 while current_row < self.size {
161 current_row += 1;
162 self.row_ptr[current_row] = self.col_idx.len();
163 }
164
165 self.temp_map.clear();
167 }
168
169 pub fn nnz(&self) -> usize {
171 self.values.len() + self.temp_map.len()
172 }
173
174 pub fn size(&self) -> usize {
176 self.size
177 }
178
179 pub fn density(&self) -> f64 {
181 let total = self.size * self.size;
182 if total == 0 {
183 0.0
184 } else {
185 self.nnz() as f64 / total as f64
186 }
187 }
188
189 #[allow(clippy::needless_range_loop)]
191 pub fn to_dense(&mut self) -> Vec<Vec<f64>> {
192 self.finalize();
193
194 let mut dense = vec![vec![0.0; self.size]; self.size];
195
196 for row in 0..self.size {
197 let start = self.row_ptr[row];
198 let end = self.row_ptr[row + 1];
199
200 for i in start..end {
201 let col = self.col_idx[i];
202 let value = self.values[i];
203 dense[row][col] = value;
204 }
205 }
206
207 dense
208 }
209
210 pub fn from_kernel_with_threshold(
212 data: &[Vec<f64>],
213 kernel: &dyn Kernel,
214 threshold: f64,
215 ) -> Result<Self> {
216 let n = data.len();
217 let mut matrix = Self::new(n);
218
219 for i in 0..n {
220 for j in 0..n {
221 let value = kernel.compute(&data[i], &data[j])?;
222 if value.abs() >= threshold {
223 matrix.set(i, j, value);
224 }
225 }
226 }
227
228 matrix.finalize();
229 Ok(matrix)
230 }
231
232 pub fn row(&mut self, row_idx: usize) -> Option<Vec<(usize, f64)>> {
234 if row_idx >= self.size {
235 return None;
236 }
237
238 self.finalize();
239
240 let start = self.row_ptr[row_idx];
241 let end = self.row_ptr[row_idx + 1];
242
243 let mut row_data = Vec::new();
244 for i in start..end {
245 row_data.push((self.col_idx[i], self.values[i]));
246 }
247
248 Some(row_data)
249 }
250}
251
/// Configuration for building a [`SparseKernelMatrix`] from data points and a kernel.
pub struct SparseKernelMatrixBuilder {
    /// Kernel values whose magnitude is below this are not stored.
    threshold: f64,
    /// Optional cap on how many entries are kept per row; `None` means no cap.
    max_entries_per_row: Option<usize>,
}
259
260impl SparseKernelMatrixBuilder {
261 pub fn new() -> Self {
263 Self {
264 threshold: 1e-10,
265 max_entries_per_row: None,
266 }
267 }
268
269 pub fn with_threshold(mut self, threshold: f64) -> Result<Self> {
271 if threshold < 0.0 {
272 return Err(KernelError::InvalidParameter {
273 parameter: "threshold".to_string(),
274 value: threshold.to_string(),
275 reason: "must be non-negative".to_string(),
276 });
277 }
278 self.threshold = threshold;
279 Ok(self)
280 }
281
282 pub fn with_max_entries_per_row(mut self, max_entries: usize) -> Result<Self> {
284 if max_entries == 0 {
285 return Err(KernelError::InvalidParameter {
286 parameter: "max_entries_per_row".to_string(),
287 value: max_entries.to_string(),
288 reason: "must be positive".to_string(),
289 });
290 }
291 self.max_entries_per_row = Some(max_entries);
292 Ok(self)
293 }
294
295 pub fn build(&self, data: &[Vec<f64>], kernel: &dyn Kernel) -> Result<SparseKernelMatrix> {
297 let n = data.len();
298 let mut matrix = SparseKernelMatrix::new(n);
299
300 for i in 0..n {
301 let mut row_entries = Vec::new();
302
303 for j in 0..n {
305 let value = kernel.compute(&data[i], &data[j])?;
306 if value.abs() >= self.threshold {
307 row_entries.push((j, value));
308 }
309 }
310
311 if let Some(max_entries) = self.max_entries_per_row {
313 if row_entries.len() > max_entries {
314 row_entries.sort_by(|(_, a), (_, b)| b.abs().partial_cmp(&a.abs()).unwrap());
316 row_entries.truncate(max_entries);
317 }
318 }
319
320 for (j, value) in row_entries {
322 matrix.set(i, j, value);
323 }
324 }
325
326 matrix.finalize();
327 Ok(matrix)
328 }
329}
330
331impl Default for SparseKernelMatrixBuilder {
332 fn default() -> Self {
333 Self::new()
334 }
335}
336
337impl SparseKernelMatrix {
339 pub fn spmv(&mut self, x: &[f64]) -> Result<Vec<f64>> {
341 if x.len() != self.size {
342 return Err(KernelError::InvalidParameter {
343 parameter: "x".to_string(),
344 value: x.len().to_string(),
345 reason: format!("vector length must match matrix size {}", self.size),
346 });
347 }
348
349 self.finalize();
350
351 let mut y = vec![0.0; self.size];
352
353 for (row, y_elem) in y.iter_mut().enumerate() {
354 let start = self.row_ptr[row];
355 let end = self.row_ptr[row + 1];
356
357 let mut sum = 0.0;
358 for i in start..end {
359 let col = self.col_idx[i];
360 let value = self.values[i];
361 sum += value * x[col];
362 }
363 *y_elem = sum;
364 }
365
366 Ok(y)
367 }
368
369 pub fn transpose(&self) -> Result<Self> {
371 let mut transposed = Self::new(self.size);
372
373 for row in 0..self.size {
374 let start = self.row_ptr[row];
375 let end = self.row_ptr[row + 1];
376
377 for i in start..end {
378 let col = self.col_idx[i];
379 let value = self.values[i];
380 transposed.set(col, row, value);
381 }
382 }
383
384 transposed.finalize();
385 Ok(transposed)
386 }
387
388 pub fn add(&mut self, other: &Self) -> Result<Self> {
390 if self.size != other.size {
391 return Err(KernelError::InvalidParameter {
392 parameter: "other".to_string(),
393 value: other.size.to_string(),
394 reason: format!("matrix sizes must match: {} vs {}", self.size, other.size),
395 });
396 }
397
398 self.finalize();
399
400 let mut other_finalized = other.clone();
402 other_finalized.finalize();
403
404 let mut result = Self::new(self.size);
405
406 for row in 0..self.size {
408 let start = self.row_ptr[row];
409 let end = self.row_ptr[row + 1];
410
411 for i in start..end {
412 let col = self.col_idx[i];
413 let value = self.values[i];
414 result.set(row, col, value);
415 }
416 }
417
418 for row in 0..other_finalized.size {
420 let start = other_finalized.row_ptr[row];
421 let end = other_finalized.row_ptr[row + 1];
422
423 for i in start..end {
424 let col = other_finalized.col_idx[i];
425 let value = other_finalized.values[i];
426 let existing = result.get(row, col).unwrap_or(0.0);
427 result.set(row, col, existing + value);
428 }
429 }
430
431 result.finalize();
432 Ok(result)
433 }
434
435 pub fn frobenius_norm(&self) -> f64 {
437 let mut sum_squares = 0.0;
438
439 for row in 0..self.size {
440 let start = self.row_ptr[row];
441 let end = self.row_ptr[row + 1];
442
443 for i in start..end {
444 let value = self.values[i];
445 sum_squares += value * value;
446 }
447 }
448
449 sum_squares.sqrt()
450 }
451
452 pub fn iter_nonzeros(&mut self) -> SparseMatrixIterator<'_> {
454 self.finalize();
455 SparseMatrixIterator {
456 matrix: self,
457 current_row: 0,
458 current_idx: 0,
459 }
460 }
461
462 pub fn scale(&mut self, scalar: f64) {
464 for value in &mut self.values {
465 *value *= scalar;
466 }
467
468 for value in self.temp_map.values_mut() {
469 *value *= scalar;
470 }
471 }
472}
473
474pub struct SparseMatrixIterator<'a> {
476 matrix: &'a SparseKernelMatrix,
477 current_row: usize,
478 current_idx: usize,
479}
480
481impl<'a> Iterator for SparseMatrixIterator<'a> {
482 type Item = (usize, usize, f64);
483
484 fn next(&mut self) -> Option<Self::Item> {
485 while self.current_row < self.matrix.size {
486 let row_end = self.matrix.row_ptr[self.current_row + 1];
487
488 if self.current_idx < row_end {
489 let col = self.matrix.col_idx[self.current_idx];
490 let value = self.matrix.values[self.current_idx];
491 self.current_idx += 1;
492 return Some((self.current_row, col, value));
493 }
494
495 self.current_row += 1;
496 self.current_idx = self
497 .matrix
498 .row_ptr
499 .get(self.current_row)
500 .copied()
501 .unwrap_or(0);
502 }
503
504 None
505 }
506}
507
508impl SparseKernelMatrixBuilder {
510 pub fn build_parallel(
512 &self,
513 data: &[Vec<f64>],
514 kernel: &dyn Kernel,
515 ) -> Result<SparseKernelMatrix> {
516 use rayon::prelude::*;
517
518 let n = data.len();
519 let mut matrix = SparseKernelMatrix::new(n);
520
521 let row_data: Vec<Vec<(usize, f64)>> = (0..n)
523 .into_par_iter()
524 .map(|i| {
525 let mut row_entries = Vec::new();
526
527 for j in 0..n {
528 match kernel.compute(&data[i], &data[j]) {
529 Ok(value) => {
530 if value.abs() >= self.threshold {
531 row_entries.push((j, value));
532 }
533 }
534 Err(_) => continue,
535 }
536 }
537
538 if let Some(max_entries) = self.max_entries_per_row {
540 if row_entries.len() > max_entries {
541 row_entries
542 .sort_by(|(_, a), (_, b)| b.abs().partial_cmp(&a.abs()).unwrap());
543 row_entries.truncate(max_entries);
544 }
545 }
546
547 row_entries
548 })
549 .collect();
550
551 for (i, row_entries) in row_data.into_iter().enumerate() {
553 for (j, value) in row_entries {
554 matrix.set(i, j, value);
555 }
556 }
557
558 matrix.finalize();
559 Ok(matrix)
560 }
561}
562
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor_kernels::LinearKernel;

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// Creates a `size` x `size` matrix and writes `entries` into it without finalizing.
    fn staged(size: usize, entries: &[(usize, usize, f64)]) -> SparseKernelMatrix {
        let mut matrix = SparseKernelMatrix::new(size);
        for &(row, col, value) in entries {
            matrix.set(row, col, value);
        }
        matrix
    }

    /// True when `actual` is within `EPS` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    /// The two unit vectors of the plane.
    fn two_points() -> Vec<Vec<f64>> {
        vec![vec![1.0, 0.0], vec![0.0, 1.0]]
    }

    /// The two unit vectors plus the point halfway between them.
    fn three_points() -> Vec<Vec<f64>> {
        vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]]
    }

    #[test]
    fn test_sparse_matrix_creation() {
        let matrix = SparseKernelMatrix::new(3);

        assert_eq!(matrix.size(), 3);
        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_sparse_matrix_set_get() {
        let matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6)]);

        assert_eq!(matrix.get(0, 1), Some(0.8));
        assert_eq!(matrix.get(1, 2), Some(0.6));
        assert_eq!(matrix.get(0, 2), None);
    }

    #[test]
    fn test_sparse_matrix_finalize() {
        let mut matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6), (2, 0, 0.4)]);

        matrix.finalize();

        assert_eq!(matrix.get(0, 1), Some(0.8));
        assert_eq!(matrix.get(1, 2), Some(0.6));
        assert_eq!(matrix.get(2, 0), Some(0.4));
    }

    #[test]
    fn test_sparse_matrix_nnz() {
        let matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6)]);

        assert_eq!(matrix.nnz(), 2);
    }

    #[test]
    fn test_sparse_matrix_density() {
        let matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6)]);

        assert!(close(matrix.density(), 2.0 / 9.0));
    }

    #[test]
    fn test_sparse_matrix_to_dense() {
        let mut matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6)]);

        let dense = matrix.to_dense();

        assert_eq!(dense.len(), 3);
        assert!(close(dense[0][1], 0.8));
        assert!(close(dense[1][2], 0.6));
        assert!(close(dense[0][0], 0.0));
    }

    #[test]
    fn test_sparse_matrix_from_kernel() {
        let kernel = LinearKernel::new();
        let data = three_points();

        let mut matrix =
            SparseKernelMatrix::from_kernel_with_threshold(&data, &kernel, 0.1).unwrap();

        assert!(matrix.nnz() > 0);
        assert_eq!(matrix.to_dense().len(), 3);
    }

    #[test]
    fn test_sparse_matrix_row() {
        let mut matrix = staged(3, &[(0, 1, 0.8), (0, 2, 0.6)]);

        let row = matrix.row(0).unwrap();

        assert_eq!(row.len(), 2);
        assert!(row.contains(&(1, 0.8)));
        assert!(row.contains(&(2, 0.6)));
    }

    #[test]
    fn test_sparse_matrix_builder() {
        let kernel = LinearKernel::new();

        let matrix = SparseKernelMatrixBuilder::new()
            .build(&two_points(), &kernel)
            .unwrap();

        assert!(matrix.nnz() > 0);
    }

    #[test]
    fn test_sparse_matrix_builder_with_threshold() {
        let kernel = LinearKernel::new();
        let builder = SparseKernelMatrixBuilder::new()
            .with_threshold(0.5)
            .unwrap();

        let matrix = builder.build(&two_points(), &kernel).unwrap();

        assert!(matrix.nnz() > 0);
    }

    #[test]
    fn test_sparse_matrix_builder_invalid_threshold() {
        assert!(SparseKernelMatrixBuilder::new()
            .with_threshold(-0.1)
            .is_err());
    }

    #[test]
    fn test_sparse_matrix_builder_max_entries() {
        let kernel = LinearKernel::new();
        let builder = SparseKernelMatrixBuilder::new()
            .with_max_entries_per_row(2)
            .unwrap();

        let mut matrix = builder.build(&three_points(), &kernel).unwrap();

        for i in 0..matrix.size() {
            let row = matrix.row(i).unwrap();
            assert!(row.len() <= 2);
        }
    }

    #[test]
    fn test_sparse_matrix_builder_invalid_max_entries() {
        assert!(SparseKernelMatrixBuilder::new()
            .with_max_entries_per_row(0)
            .is_err());
    }

    #[test]
    fn test_sparse_matrix_zero_threshold() {
        // A magnitude below the storage tolerance must not be stored at all.
        let mut matrix = staged(3, &[(0, 1, 1e-11)]);
        matrix.finalize();

        assert_eq!(matrix.nnz(), 0);
    }

    #[test]
    fn test_sparse_matrix_spmv() {
        let mut matrix = staged(
            3,
            &[
                (0, 0, 2.0),
                (0, 2, 1.0),
                (1, 1, 3.0),
                (2, 0, 1.0),
                (2, 2, 2.0),
            ],
        );

        let y = matrix.spmv(&[1.0, 2.0, 3.0]).unwrap();

        assert_eq!(y.len(), 3);
        // Row-wise dot products: 2*1 + 1*3, 3*2, 1*1 + 2*3.
        assert!(close(y[0], 5.0));
        assert!(close(y[1], 6.0));
        assert!(close(y[2], 7.0));
    }

    #[test]
    fn test_sparse_matrix_spmv_invalid_size() {
        let mut matrix = staged(3, &[(0, 0, 1.0)]);

        // The vector is one element shorter than the matrix size.
        assert!(matrix.spmv(&[1.0, 2.0]).is_err());
    }

    #[test]
    fn test_sparse_matrix_transpose() {
        let mut matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6), (2, 0, 0.4)]);
        matrix.finalize();

        let transposed = matrix.transpose().unwrap();

        assert_eq!(transposed.get(1, 0), Some(0.8));
        assert_eq!(transposed.get(2, 1), Some(0.6));
        assert_eq!(transposed.get(0, 2), Some(0.4));
    }

    #[test]
    fn test_sparse_matrix_add() {
        let mut lhs = staged(3, &[(0, 0, 1.0), (0, 1, 2.0), (1, 1, 3.0)]);
        let rhs = staged(3, &[(0, 1, 1.0), (1, 2, 4.0), (2, 2, 5.0)]);

        let sum = lhs.add(&rhs).unwrap();

        assert_eq!(sum.get(0, 0), Some(1.0));
        // (0, 1) is the only coordinate present in both operands: 2.0 + 1.0.
        assert_eq!(sum.get(0, 1), Some(3.0));
        assert_eq!(sum.get(1, 1), Some(3.0));
        assert_eq!(sum.get(1, 2), Some(4.0));
        assert_eq!(sum.get(2, 2), Some(5.0));
    }

    #[test]
    fn test_sparse_matrix_add_invalid_size() {
        let mut lhs = staged(3, &[(0, 0, 1.0)]);
        let rhs = SparseKernelMatrix::new(2);

        assert!(lhs.add(&rhs).is_err());
    }

    #[test]
    fn test_sparse_matrix_frobenius_norm() {
        let mut matrix = staged(3, &[(0, 0, 3.0), (1, 1, 4.0)]);
        matrix.finalize();

        // sqrt(3^2 + 4^2)
        assert!(close(matrix.frobenius_norm(), 5.0));
    }

    #[test]
    fn test_sparse_matrix_iterator() {
        let mut matrix = staged(3, &[(0, 1, 0.8), (1, 2, 0.6), (2, 0, 0.4)]);

        let entries: Vec<_> = matrix.iter_nonzeros().collect();

        assert_eq!(entries.len(), 3);
        for expected in [(0, 1, 0.8), (1, 2, 0.6), (2, 0, 0.4)] {
            assert!(entries.contains(&expected));
        }
    }

    #[test]
    fn test_sparse_matrix_scale() {
        let mut matrix = staged(3, &[(0, 0, 2.0), (1, 1, 4.0)]);
        matrix.finalize();

        matrix.scale(0.5);

        assert_eq!(matrix.get(0, 0), Some(1.0));
        assert_eq!(matrix.get(1, 1), Some(2.0));
    }

    #[test]
    fn test_sparse_matrix_builder_parallel() {
        let builder = SparseKernelMatrixBuilder::new();
        let kernel = LinearKernel::new();
        let data = three_points();

        let parallel = builder.build_parallel(&data, &kernel).unwrap();
        assert!(parallel.nnz() > 0);

        // The parallel and the sequential build must store the same number of entries.
        let sequential = builder.build(&data, &kernel).unwrap();
        assert_eq!(parallel.nnz(), sequential.nnz());
    }

    #[test]
    fn test_sparse_matrix_parallel_with_threshold() {
        let kernel = LinearKernel::new();
        let builder = SparseKernelMatrixBuilder::new()
            .with_threshold(0.5)
            .unwrap();

        let matrix = builder.build_parallel(&three_points(), &kernel).unwrap();

        assert!(matrix.nnz() > 0);
    }
}