use crate::{determine_chunk_size, SMat, SvdFloat};
use nalgebra_sparse::na::{DMatrix, DVector};
use nalgebra_sparse::CsrMatrix;
use num_traits::Float;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{
    IntoParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelSliceMut,
};
use std::fmt::Debug;
use std::ops::AddAssign;

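/// A read-only view over a `CsrMatrix` that exposes only a subset of its columns.
///
/// The mask is applied lazily: no data is copied, and the selected columns are
/// renumbered contiguously (via `masked_to_original` / `original_to_masked`) so the
/// view can be handed to the SVD routines through the `SMat` implementation below
/// as if it were a smaller matrix.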
pub struct MaskedCSRMatrix<'a, T: Float> {
    matrix: &'a CsrMatrix<T>,
    column_mask: Vec<bool>,
    masked_to_original: Vec<usize>,
    original_to_masked: Vec<Option<usize>>,
}

impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
    pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
        assert_eq!(
            column_mask.len(),
            matrix.ncols(),
            "Column mask must have the same length as the number of columns in the matrix"
        );

        let mut masked_to_original = Vec::new();
        let mut original_to_masked = vec![None; column_mask.len()];
        let mut masked_index = 0;

        for (i, &is_included) in column_mask.iter().enumerate() {
            if is_included {
                masked_to_original.push(i);
                original_to_masked[i] = Some(masked_index);
                masked_index += 1;
            }
        }

        Self {
            matrix,
            column_mask,
            masked_to_original,
            original_to_masked,
        }
    }

    pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
        let mut mask = vec![false; matrix.ncols()];
        for &col in columns {
            assert!(col < matrix.ncols(), "Column index out of bounds");
            mask[col] = true;
        }
        Self::new(matrix, mask)
    }

    pub fn uses_all_columns(&self) -> bool {
        self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
    }

    pub fn ensure_identical_results_mode(&self) -> bool {
        let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
        is_small_matrix && self.uses_all_columns()
    }
}

impl<
        T: Float
            + AddAssign
            + Sync
            + Send
            + std::ops::MulAssign
            + Debug
            + 'static
            + std::iter::Sum
            + std::ops::SubAssign
            + num_traits::FromPrimitive,
    > SMat<T> for MaskedCSRMatrix<'_, T>
{
    fn nrows(&self) -> usize {
        self.matrix.nrows()
    }

    fn ncols(&self) -> usize {
        self.masked_to_original.len()
    }

    fn nnz(&self) -> usize {
        let (major_offsets, minor_indices, _) = self.matrix.csr_data();
        let mut count = 0;

        for i in 0..self.matrix.nrows() {
            for j in major_offsets[i]..major_offsets[i + 1] {
                let col = minor_indices[j];
                if self.column_mask[col] {
                    count += 1;
                }
            }
        }
        count
    }

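    /// Sparse matrix-vector product used by the Lanczos code: computes
    /// `y = A * x` (or `y = A^T * x` when `transposed` is true), where `A` is the
    /// column-masked view. Delegates to the underlying matrix when no column is
    /// masked out; otherwise the rows are processed in parallel chunks.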
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        // When every column is selected the mask is a no-op, so the product can be
        // delegated to the underlying matrix. The delegation is only valid in that
        // case: with a real mask, `x`/`y` use the masked dimensions and would not
        // match the full matrix.
        if self.uses_all_columns() {
            self.matrix.svd_opa(x, y, transposed);
            return;
        }

        y.fill(T::zero());

        if !transposed {
            // Resolve the original-column -> masked-column lookup once per call.
            let valid_indices: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked[col])
                .collect();

            let rows = self.matrix.nrows();
            let chunk_size = std::cmp::max(16, rows / (rayon::current_num_threads() * 2));

            y.par_chunks_mut(chunk_size)
                .enumerate()
                .for_each(|(chunk_idx, y_chunk)| {
                    let start_row = chunk_idx * chunk_size;
                    let end_row = (start_row + y_chunk.len()).min(rows);

                    for i in start_row..end_row {
                        let row_idx = i - start_row;
                        let mut sum = T::zero();

                        let row_start = major_offsets[i];
                        let row_end = major_offsets[i + 1];

                        // Manually unrolled: four non-zeros per iteration, with the
                        // remaining tail handled below.
                        let mut j = row_start;
                        while j + 4 <= row_end {
                            for offset in 0..4 {
                                let idx = j + offset;
                                let col = minor_indices[idx];
                                if let Some(masked_col) = valid_indices[col] {
                                    sum += values[idx] * x[masked_col];
                                }
                            }
                            j += 4;
                        }

                        while j < row_end {
                            let col = minor_indices[j];
                            if let Some(masked_col) = valid_indices[col] {
                                sum += values[j] * x[masked_col];
                            }
                            j += 1;
                        }

                        y_chunk[row_idx] = sum;
                    }
                });
        } else {
            let nrows = self.matrix.nrows();
            let chunk_size = determine_chunk_size(nrows);

            // Each chunk accumulates into its own local vector; the partial results
            // are summed into `y` afterwards to avoid synchronised writes.
            let results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_y = vec![T::zero(); y.len()];

                    for i in start..end {
                        let row_val = x[i];
                        if row_val.is_zero() {
                            continue;
                        }

                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                local_y[masked_col] += values[j] * row_val;
                            }
                        }
                    }
                    local_y
                })
                .collect();

            for local_y in results {
                for (idx, &val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += val;
                    }
                }
            }
        }
    }

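    /// Returns the mean of every selected column (length `self.ncols()`), computed
    /// over all rows of the underlying matrix; missing entries count as zero.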
    fn compute_column_means(&self) -> Vec<T> {
        let rows = self.nrows();
        let masked_cols = self.ncols();
        let row_count_recip = T::one() / T::from(rows).unwrap();

        let mut col_sums = vec![T::zero(); masked_cols];
        let (row_offsets, col_indices, values) = self.matrix.csr_data();

        for i in 0..rows {
            for j in row_offsets[i]..row_offsets[i + 1] {
                let original_col = col_indices[j];
                if let Some(masked_col) = self.original_to_masked[original_col] {
                    col_sums[masked_col] += values[j];
                }
            }
        }

        for sum in col_sums.iter_mut() {
            *sum *= row_count_recip;
        }

        col_sums
    }

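    /// Computes `result = A * dense` (or `A^T * dense` when `transpose_self` is
    /// true) for the masked view `A`, parallelising over rows of the sparse matrix.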
    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        let m_rows = if transpose_self {
            self.ncols()
        } else {
            self.nrows()
        };
        let m_cols = if transpose_self {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            dense.nrows(),
            m_cols,
            "Dense matrix has incompatible row count"
        );
        assert_eq!(
            result.nrows(),
            m_rows,
            "Result matrix has incompatible row count"
        );
        assert_eq!(
            result.ncols(),
            dense.ncols(),
            "Result matrix has incompatible column count"
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        if !transpose_self {
            let rows = self.matrix.nrows();
            let dense_cols = dense.ncols();

            let valid_cols: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked.get(col).copied().flatten())
                .collect();

            let row_results: Vec<(usize, Vec<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut row_result = vec![T::zero(); dense_cols];

                    let row_start = major_offsets[row];
                    let row_end = major_offsets[row + 1];

                    let mut j = row_start;
                    while j + 4 <= row_end {
                        for offset in 0..4 {
                            let idx = j + offset;
                            let col = minor_indices[idx];
                            if let Some(masked_col) = valid_cols[col] {
                                let val = values[idx];

                                for c in 0..dense_cols {
                                    row_result[c] += val * dense[(masked_col, c)];
                                }
                            }
                        }
                        j += 4;
                    }

                    while j < row_end {
                        let col = minor_indices[j];
                        if let Some(masked_col) = valid_cols[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                row_result[c] += val * dense[(masked_col, c)];
                            }
                        }
                        j += 1;
                    }

                    (row, row_result)
                })
                .collect();

            for (row, row_values) in row_results {
                for c in 0..dense_cols {
                    result[(row, c)] = row_values[c];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();
            let dense_cols = dense.ncols();

            result.fill(T::zero());

            let valid_cols: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked.get(col).copied().flatten())
                .collect();

            let chunk_size = determine_chunk_size(nrows);

            let partial_results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    // Row-major flattened (ncols x dense_cols) accumulator local to the chunk.
                    let mut local_result = vec![T::zero(); ncols * dense_cols];

                    for i in start..end {
                        // Copy row `i` of `dense` into a contiguous buffer. `DMatrix`
                        // is column-major, so a raw slice starting at `i * dense_cols`
                        // would not yield this row.
                        let dense_row: Vec<T> = (0..dense_cols).map(|c| dense[(i, c)]).collect();

                        let row_start = major_offsets[i];
                        let row_end = major_offsets[i + 1];

                        let mut j = row_start;
                        while j + 8 <= row_end {
                            for offset in 0..8 {
                                let idx = j + offset;
                                let col = minor_indices[idx];
                                if let Some(masked_col) = valid_cols[col] {
                                    let val = values[idx];
                                    let base_offset = masked_col * dense_cols;

                                    let mut c = 0;
                                    while c + 4 <= dense_cols {
                                        local_result[base_offset + c] += val * dense_row[c];
                                        local_result[base_offset + c + 1] += val * dense_row[c + 1];
                                        local_result[base_offset + c + 2] += val * dense_row[c + 2];
                                        local_result[base_offset + c + 3] += val * dense_row[c + 3];
                                        c += 4;
                                    }

                                    while c < dense_cols {
                                        local_result[base_offset + c] += val * dense_row[c];
                                        c += 1;
                                    }
                                }
                            }
                            j += 8;
                        }

                        while j < row_end {
                            let col = minor_indices[j];
                            if let Some(masked_col) = valid_cols[col] {
                                let val = values[j];
                                let base_offset = masked_col * dense_cols;

                                for c in 0..dense_cols {
                                    local_result[base_offset + c] += val * dense_row[c];
                                }
                            }
                            j += 1;
                        }
                    }

                    local_result
                })
                .collect();

            // Accumulate the per-chunk partial products into `result` in blocks to
            // keep the writes cache-friendly.
            const BLOCK_SIZE: usize = 64;
            for local_result in partial_results {
                for r_block in (0..ncols).step_by(BLOCK_SIZE) {
                    let r_end = (r_block + BLOCK_SIZE).min(ncols);

                    for c_block in (0..dense_cols).step_by(BLOCK_SIZE) {
                        let c_end = (c_block + BLOCK_SIZE).min(dense_cols);

                        for r in r_block..r_end {
                            for c in c_block..c_end {
                                let val = local_result[r * dense_cols + c];
                                if !val.is_zero() {
                                    result[(r, c)] += val;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

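    /// Like `multiply_with_dense`, but for the implicitly centered matrix
    /// `A - 1 * means^T`, where `means` holds the per-column means of the masked
    /// view. The centering term is applied as a rank-one correction instead of
    /// densifying the sparse matrix.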
    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        let dense_cols = dense.ncols();
        let dense_rows = dense.nrows();

        let col_sums: Vec<T> = (0..dense_cols)
            .into_par_iter()
            .map(|c| (0..dense_rows).map(|i| dense[(i, c)]).sum())
            .collect();

        if !transpose_self {
            let rows = self.matrix.nrows();

            // Centering: (A - 1 means^T) * dense = A * dense - 1 * (means^T * dense).
            // Pre-compute means^T * dense, i.e. for each dense column the dot product
            // of the masked column means with that column.
            let mean_adjustments: Vec<T> = (0..dense_cols)
                .map(|c| {
                    means
                        .iter()
                        .enumerate()
                        .map(|(masked_col, &mean_val)| mean_val * dense[(masked_col, c)])
                        .sum()
                })
                .collect();

            let row_updates: Vec<(usize, Vec<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut row_result = vec![T::zero(); dense_cols];

                    for j in major_offsets[row]..major_offsets[row + 1] {
                        let col = minor_indices[j];
                        if let Some(masked_col) = self.original_to_masked[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                row_result[c] += val * dense[(masked_col, c)];
                            }
                        }
                    }

                    for c in 0..dense_cols {
                        row_result[c] -= mean_adjustments[c];
                    }

                    (row, row_result)
                })
                .collect();

            for (row, row_values) in row_updates {
                for c in 0..dense_cols {
                    result[(row, c)] = row_values[c];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();

            result.fill(T::zero());

            let chunk_size = determine_chunk_size(nrows);

            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = std::cmp::min(start + chunk_size, nrows);

                    let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);

                    for i in start..end {
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                let sparse_val = values[j];

                                for c in 0..dense_cols {
                                    local_result[(masked_col, c)] += sparse_val * dense[(i, c)];
                                }
                            }
                        }
                    }

                    // Apply the centering correction means * colsum(dense), scaled by
                    // the fraction of rows this chunk covers so the per-chunk
                    // corrections sum to the full adjustment.
                    let chunk_fraction =
                        T::from_f64((end - start) as f64 / dense_rows as f64).unwrap();

                    for masked_col in 0..ncols {
                        if masked_col < means.len() {
                            let mean = means[masked_col];
                            for c in 0..dense_cols {
                                local_result[(masked_col, c)] -=
                                    mean * col_sums[c] * chunk_fraction;
                            }
                        }
                    }

                    local_result
                })
                .collect();

            const BLOCK_SIZE: usize = 32;
            for local_result in partial_results {
                for r_block in 0..ncols.div_ceil(BLOCK_SIZE) {
                    let r_start = r_block * BLOCK_SIZE;
                    let r_end = std::cmp::min(r_start + BLOCK_SIZE, ncols);

                    for c_block in 0..dense_cols.div_ceil(BLOCK_SIZE) {
                        let c_start = c_block * BLOCK_SIZE;
                        let c_end = std::cmp::min(c_start + BLOCK_SIZE, dense_cols);

                        for r in r_start..r_end {
                            for c in c_start..c_end {
                                result[(r, c)] += local_result[(r, c)];
                            }
                        }
                    }
                }
            }
        }
    }

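    /// Computes `result = Q^T * A` for the masked view `A`, where `Q` has one row
    /// per row of `A`; the output has one column per selected column.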
    fn multiply_transposed_by_dense(&self, q: &DMatrix<T>, result: &mut DMatrix<T>) {
        let q_rows = q.nrows();
        let q_cols = q.ncols();
        let masked_cols = self.ncols();

        assert_eq!(
            q_rows,
            self.nrows(),
            "Q matrix has incompatible row count: expected {}, got {}",
            self.nrows(),
            q_rows
        );
        assert_eq!(
            result.nrows(),
            q_cols,
            "Result matrix has incompatible row count: expected {}, got {}",
            q_cols,
            result.nrows()
        );
        assert_eq!(
            result.ncols(),
            masked_cols,
            "Result matrix has incompatible column count: expected {}, got {}",
            masked_cols,
            result.ncols()
        );

        result.fill(T::zero());

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();
        let nrows = self.matrix.nrows();
        let chunk_size = determine_chunk_size(nrows);

        // Fast path: when every column is kept, the mask lookup can be skipped
        // because original and masked column indices coincide.
        if self.uses_all_columns() && (nrows < 1000 && self.matrix.ncols() < 1000) {
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                    for row in start..end {
                        for idx in major_offsets[row]..major_offsets[row + 1] {
                            let col = minor_indices[idx];
                            let sparse_val = values[idx];

                            for q_col in 0..q_cols {
                                local_result[(q_col, col)] += q[(row, q_col)] * sparse_val;
                            }
                        }
                    }

                    local_result
                })
                .collect();

            for local_result in partial_results {
                for r in 0..q_cols {
                    for c in 0..masked_cols {
                        let val = local_result[(r, c)];
                        if !val.is_zero() {
                            result[(r, c)] += val;
                        }
                    }
                }
            }
        } else {
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                    for row in start..end {
                        for idx in major_offsets[row]..major_offsets[row + 1] {
                            let original_col = minor_indices[idx];

                            if let Some(masked_col) = self.original_to_masked[original_col] {
                                let sparse_val = values[idx];

                                for q_col in 0..q_cols {
                                    local_result[(q_col, masked_col)] +=
                                        q[(row, q_col)] * sparse_val;
                                }
                            }
                        }
                    }

                    local_result
                })
                .collect();

            for local_result in partial_results {
                for r in 0..q_cols {
                    for c in 0..masked_cols {
                        let val = local_result[(r, c)];
                        if !val.is_zero() {
                            result[(r, c)] += val;
                        }
                    }
                }
            }
        }
    }

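    /// Centered variant of `multiply_transposed_by_dense`: computes
    /// `result = Q^T * (A - 1 * means^T)`, with the rank-one centering correction
    /// `colsum(Q) * means^T` applied per chunk instead of materialising the
    /// centered matrix.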
    fn multiply_transposed_by_dense_centered(
        &self,
        q: &DMatrix<T>,
        result: &mut DMatrix<T>,
        means: &DVector<T>,
    ) {
        let q_rows = q.nrows();
        let q_cols = q.ncols();
        let masked_cols = self.ncols();

        assert_eq!(
            q_rows,
            self.nrows(),
            "Q matrix has incompatible row count: expected {}, got {}",
            self.nrows(),
            q_rows
        );
        assert_eq!(
            result.nrows(),
            q_cols,
            "Result matrix has incompatible row count: expected {}, got {}",
            q_cols,
            result.nrows()
        );
        assert_eq!(
            result.ncols(),
            masked_cols,
            "Result matrix has incompatible column count: expected {}, got {}",
            masked_cols,
            result.ncols()
        );
        assert_eq!(
            means.len(),
            masked_cols,
            "Means vector has incompatible length: expected {}, got {}",
            masked_cols,
            means.len()
        );

        result.fill(T::zero());

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        // Column sums of Q, needed for the rank-one centering correction
        // Q^T * 1 * means^T = colsum(Q) * means^T.
        let q_col_sums: Vec<T> = (0..q_cols)
            .into_par_iter()
            .map(|col| (0..q_rows).map(|row| q[(row, col)]).sum())
            .collect();

        let nrows = self.matrix.nrows();
        let chunk_size = determine_chunk_size(nrows);

        let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
            .into_par_iter()
            .map(|chunk_idx| {
                let start = chunk_idx * chunk_size;
                let end = std::cmp::min(start + chunk_size, nrows);

                let mut local_result = DMatrix::<T>::zeros(q_cols, masked_cols);

                for row in start..end {
                    for idx in major_offsets[row]..major_offsets[row + 1] {
                        let original_col = minor_indices[idx];

                        if let Some(masked_col) = self.original_to_masked[original_col] {
                            let sparse_val = values[idx];

                            for q_col in 0..q_cols {
                                local_result[(q_col, masked_col)] +=
                                    q[(row, q_col)] * sparse_val;
                            }
                        }
                    }
                }

                // Spread the centering correction across chunks in proportion to the
                // rows each chunk covers, so the per-chunk corrections sum to
                // colsum(Q) * means^T.
                let chunk_fraction = T::from_f64((end - start) as f64 / q_rows as f64).unwrap();

                for q_col in 0..q_cols {
                    let q_sum = q_col_sums[q_col];
                    for masked_col in 0..masked_cols {
                        local_result[(q_col, masked_col)] -=
                            q_sum * means[masked_col] * chunk_fraction;
                    }
                }

                local_result
            })
            .collect();

        const BLOCK_SIZE: usize = 64;
        for local_result in partial_results {
            for r_block in 0..q_cols.div_ceil(BLOCK_SIZE) {
                let r_start = r_block * BLOCK_SIZE;
                let r_end = std::cmp::min(r_start + BLOCK_SIZE, q_cols);

                for c_block in 0..masked_cols.div_ceil(BLOCK_SIZE) {
                    let c_start = c_block * BLOCK_SIZE;
                    let c_end = std::cmp::min(c_start + BLOCK_SIZE, masked_cols);

                    for r in r_start..r_end {
                        for c in c_start..c_end {
                            result[(r, c)] += local_result[(r, c)];
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    #[test]
    fn test_masked_matrix() {
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);

        let csr = CsrMatrix::from(&coo);

        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6);

        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        let mut rng = StdRng::seed_from_u64(42);

        let nrows = 14;
        let ncols = 10;
        let nnz = 40;

        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);

            coo.push(row, col, val);
        }

        let csr = CsrMatrix::from(&coo);

        let selected_columns = vec![1, 3, 5, 7];

        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());

        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        let physical_csr = CsrMatrix::from(&physical_subset);

        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        assert_eq!(svd_masked.d, svd_physical.d);

        let epsilon = 1e-10;

        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }
    }
}