1use crate::{determine_chunk_size, SMat, SvdFloat};
2use nalgebra_sparse::na::{DMatrix, DVector};
3use nalgebra_sparse::CsrMatrix;
4use num_traits::Float;
5use rayon::iter::IndexedParallelIterator;
6use rayon::iter::ParallelIterator;
7use rayon::prelude::{
8 IntoParallelIterator, IntoParallelRefIterator, ParallelBridge, ParallelSliceMut,
9};
10use std::fmt::Debug;
11use std::ops::AddAssign;
12
13pub struct MaskedCSRMatrix<'a, T: Float> {
14 matrix: &'a CsrMatrix<T>,
15 column_mask: Vec<bool>,
16 masked_to_original: Vec<usize>,
17 original_to_masked: Vec<Option<usize>>,
18}
19
20impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
21 pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
22 assert_eq!(
23 column_mask.len(),
24 matrix.ncols(),
25 "Column mask must have the same length as the number of columns in the matrix"
26 );
27
28 let mut masked_to_original = Vec::new();
29 let mut original_to_masked = vec![None; column_mask.len()];
30 let mut masked_index = 0;
31
32 for (i, &is_included) in column_mask.iter().enumerate() {
33 if is_included {
34 masked_to_original.push(i);
35 original_to_masked[i] = Some(masked_index);
36 masked_index += 1;
37 }
38 }
39
40 Self {
41 matrix,
42 column_mask,
43 masked_to_original,
44 original_to_masked,
45 }
46 }
47
48 pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
49 let mut mask = vec![false; matrix.ncols()];
50 for &col in columns {
51 assert!(col < matrix.ncols(), "Column index out of bounds");
52 mask[col] = true;
53 }
54 Self::new(matrix, mask)
55 }
56
57 pub fn uses_all_columns(&self) -> bool {
58 self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
59 }
60
61 pub fn ensure_identical_results_mode(&self) -> bool {
62 let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
64 is_small_matrix && self.uses_all_columns()
65 }
66}
67
68impl<
69 'a,
70 T: Float
71 + AddAssign
72 + Sync
73 + Send
74 + std::ops::MulAssign
75 + Debug
76 + 'static
77 + std::iter::Sum
78 + std::ops::SubAssign
79 + num_traits::FromPrimitive,
80 > SMat<T> for MaskedCSRMatrix<'a, T>
81{
82 fn nrows(&self) -> usize {
83 self.matrix.nrows()
84 }
85
86 fn ncols(&self) -> usize {
87 self.masked_to_original.len()
88 }
89
90 fn nnz(&self) -> usize {
91 let (major_offsets, minor_indices, _) = self.matrix.csr_data();
92 let mut count = 0;
93
94 for i in 0..self.matrix.nrows() {
95 for j in major_offsets[i]..major_offsets[i + 1] {
96 let col = minor_indices[j];
97 if self.column_mask[col] {
98 count += 1;
99 }
100 }
101 }
102 count
103 }
104
105 fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
106 let nrows = if transposed {
107 self.ncols()
108 } else {
109 self.nrows()
110 };
111 let ncols = if transposed {
112 self.nrows()
113 } else {
114 self.ncols()
115 };
116
117 assert_eq!(
118 x.len(),
119 ncols,
120 "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
121 x.len(),
122 ncols
123 );
124 assert_eq!(
125 y.len(),
126 nrows,
127 "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
128 y.len(),
129 nrows
130 );
131
132 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
133
134 if self.uses_all_columns() || (self.matrix.nrows() < 1000 && self.matrix.ncols() < 1000) {
135 if !transposed {
137 self.matrix.svd_opa(x, y, false);
139 } else {
140 self.matrix.svd_opa(x, y, true);
142 }
143 return;
144 }
145
146 y.fill(T::zero());
147
148 if !transposed {
149 let valid_indices: Vec<Option<usize>> = (0..self.matrix.ncols())
151 .map(|col| self.original_to_masked[col])
152 .collect();
153
154 let rows = self.matrix.nrows();
156 let chunk_size = std::cmp::max(16, rows / (rayon::current_num_threads() * 2));
157
158 y.par_chunks_mut(chunk_size)
160 .enumerate()
161 .for_each(|(chunk_idx, y_chunk)| {
162 let start_row = chunk_idx * chunk_size;
163 let end_row = (start_row + y_chunk.len()).min(rows);
164
165 for i in start_row..end_row {
166 let row_idx = i - start_row;
167 let mut sum = T::zero();
168
169 let row_start = major_offsets[i];
171 let row_end = major_offsets[i + 1];
172
173 let mut j = row_start;
175 while j + 4 <= row_end {
176 for offset in 0..4 {
177 let idx = j + offset;
178 let col = minor_indices[idx];
179 if let Some(masked_col) = valid_indices[col] {
180 sum += values[idx] * x[masked_col];
181 }
182 }
183 j += 4;
184 }
185
186 while j < row_end {
188 let col = minor_indices[j];
189 if let Some(masked_col) = valid_indices[col] {
190 sum += values[j] * x[masked_col];
191 }
192 j += 1;
193 }
194
195 y_chunk[row_idx] = sum;
196 }
197 });
198 } else {
199 let nrows = self.matrix.nrows();
201 let chunk_size = crate::utils::determine_chunk_size(nrows);
202
203 let results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
205 .into_par_iter()
206 .map(|chunk_idx| {
207 let start = chunk_idx * chunk_size;
208 let end = (start + chunk_size).min(nrows);
209 let mut local_y = vec![T::zero(); y.len()];
210
211 for i in start..end {
213 let row_val = x[i];
214 if row_val.is_zero() {
215 continue; }
217
218 for j in major_offsets[i]..major_offsets[i + 1] {
219 let col = minor_indices[j];
220 if let Some(masked_col) = self.original_to_masked[col] {
221 local_y[masked_col] += values[j] * row_val;
222 }
223 }
224 }
225 local_y
226 })
227 .collect();
228
229 for local_y in results {
231 for (idx, &val) in local_y.iter().enumerate() {
233 if !val.is_zero() {
234 y[idx] += val;
235 }
236 }
237 }
238 }
239 }
240
241 fn compute_column_means(&self) -> Vec<T> {
242 let rows = self.nrows();
243 let masked_cols = self.ncols();
244 let row_count_recip = T::one() / T::from(rows).unwrap();
245
246 let mut col_sums = vec![T::zero(); masked_cols];
247 let (row_offsets, col_indices, values) = self.matrix.csr_data();
248
249 for i in 0..rows {
250 for j in row_offsets[i]..row_offsets[i + 1] {
251 let original_col = col_indices[j];
252 if let Some(masked_col) = self.original_to_masked[original_col] {
253 col_sums[masked_col] += values[j];
254 }
255 }
256 }
257
258 for j in 0..masked_cols {
260 col_sums[j] *= row_count_recip;
261 }
262
263 col_sums
264 }
265
266 fn multiply_with_dense(
267 &self,
268 dense: &DMatrix<T>,
269 result: &mut DMatrix<T>,
270 transpose_self: bool,
271 ) {
272 let m_rows = if transpose_self {
273 self.ncols()
274 } else {
275 self.nrows()
276 };
277 let m_cols = if transpose_self {
278 self.nrows()
279 } else {
280 self.ncols()
281 };
282
283 assert_eq!(
284 dense.nrows(),
285 m_cols,
286 "Dense matrix has incompatible row count"
287 );
288 assert_eq!(
289 result.nrows(),
290 m_rows,
291 "Result matrix has incompatible row count"
292 );
293 assert_eq!(
294 result.ncols(),
295 dense.ncols(),
296 "Result matrix has incompatible column count"
297 );
298
299 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
306
307 if !transpose_self {
308 let rows = self.matrix.nrows();
309 let dense_cols = dense.ncols();
310
311 let partial_results: Vec<(usize, DMatrix<T>)> = (0..rows)
312 .into_par_iter()
313 .map(|row| {
314 let mut local_result = DMatrix::<T>::zeros(1, dense_cols);
315
316 for j in major_offsets[row]..major_offsets[row + 1] {
317 let col = minor_indices[j];
318 if let Some(masked_col) = self.original_to_masked[col] {
319 let val = values[j];
320
321 for c in 0..dense_cols {
322 local_result[(0, c)] += val * dense[(masked_col, c)];
323 }
324 }
325 }
326
327 (row, local_result)
328 })
329 .collect();
330
331 for (row, local_result) in partial_results {
332 for c in 0..dense_cols {
333 result[(row, c)] = local_result[(0, c)];
334 }
335 }
336 } else {
337 let nrows = self.matrix.nrows();
338 let ncols = self.ncols();
339 let dense_cols = dense.ncols();
340
341 let chunk_size = determine_chunk_size(nrows);
342
343 let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
344 .into_par_iter()
345 .map(|chunk_idx| {
346 let start = chunk_idx * chunk_size;
347 let end = (start + chunk_size).min(nrows);
348
349 let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);
350
351 for i in start..end {
352 for j in major_offsets[i]..major_offsets[i + 1] {
353 let col = minor_indices[j];
354 if let Some(masked_col) = self.original_to_masked[col] {
355 let val = values[j];
356
357 for c in 0..dense_cols {
358 local_result[(masked_col, c)] += val * dense[(i, c)];
359 }
360 }
361 }
362 }
363
364 local_result
365 })
366 .collect();
367
368 for local_result in partial_results {
369 for r in 0..ncols {
370 for c in 0..dense_cols {
371 let val = local_result[(r, c)];
372 if !val.is_zero() {
373 result[(r, c)] += val;
374 }
375 }
376 }
377 }
378 }
379 }
380
381 fn multiply_with_dense_centered(
382 &self,
383 dense: &DMatrix<T>,
384 result: &mut DMatrix<T>,
385 transpose_self: bool,
386 means: &DVector<T>,
387 ) {
388 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
389
390 let dense_cols = dense.ncols();
392 let dense_rows = dense.nrows();
393
394 let col_sums: Vec<T> = (0..dense_cols)
396 .into_par_iter()
397 .map(|c| (0..dense_rows).map(|i| dense[(i, c)]).sum())
398 .collect();
399
400 if !transpose_self {
401 let rows = self.matrix.nrows();
402
403 let mean_adjustments: Vec<T> = col_sums
405 .iter()
406 .map(|&col_sum| {
407 means
408 .iter()
409 .enumerate()
410 .filter_map(|(original_idx, &mean_val)| {
411 self.original_to_masked
412 .get(original_idx)
413 .map(|_| mean_val * col_sum)
414 })
415 .sum()
416 })
417 .collect();
418
419 let chunk_size = std::cmp::max(16, rows / (rayon::current_num_threads() * 4));
420
421 let row_updates: Vec<(usize, Vec<T>)> = (0..rows)
422 .into_par_iter()
423 .map(|row| {
424 let mut row_result = vec![T::zero(); dense_cols];
425
426 for j in major_offsets[row]..major_offsets[row + 1] {
427 let col = minor_indices[j];
428 if let Some(masked_col) = self.original_to_masked[col] {
429 let val = values[j];
430
431 for c in 0..dense_cols {
432 row_result[c] += val * dense[(masked_col, c)];
433 }
434 }
435 }
436
437 for c in 0..dense_cols {
438 row_result[c] -= mean_adjustments[c];
439 }
440
441 (row, row_result)
442 })
443 .collect();
444
445 for (row, row_values) in row_updates {
446 for c in 0..dense_cols {
447 result[(row, c)] = row_values[c];
448 }
449 }
450 } else {
451 let nrows = self.matrix.nrows();
452 let ncols = self.ncols();
453
454 for i in 0..result.nrows() {
456 for j in 0..result.ncols() {
457 result[(i, j)] = T::zero();
458 }
459 }
460
461 let chunk_size = determine_chunk_size(nrows);
463
464 let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
466 .into_par_iter()
467 .map(|chunk_idx| {
468 let start = chunk_idx * chunk_size;
469 let end = std::cmp::min(start + chunk_size, nrows);
470
471 let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);
472
473 for i in start..end {
474 for j in major_offsets[i]..major_offsets[i + 1] {
475 let col = minor_indices[j];
476 if let Some(masked_col) = self.original_to_masked[col] {
477 let sparse_val = values[j];
478
479 for c in 0..dense_cols {
480 local_result[(masked_col, c)] += sparse_val * dense[(i, c)];
481 }
482 }
483 }
484 }
485
486 let chunk_fraction =
488 T::from_f64((end - start) as f64 / dense_rows as f64).unwrap();
489
490 for masked_col in 0..ncols {
491 if masked_col < means.len() {
492 let mean = means[masked_col];
493 for c in 0..dense_cols {
494 local_result[(masked_col, c)] -=
495 mean * col_sums[c] * chunk_fraction;
496 }
497 }
498 }
499
500 local_result
501 })
502 .collect();
503
504 for local_result in partial_results {
505 const BLOCK_SIZE: usize = 32;
506
507 for r_block in 0..ncols.div_ceil(BLOCK_SIZE) {
508 let r_start = r_block * BLOCK_SIZE;
509 let r_end = std::cmp::min(r_start + BLOCK_SIZE, ncols);
510
511 for c_block in 0..dense_cols.div_ceil(BLOCK_SIZE) {
512 let c_start = c_block * BLOCK_SIZE;
513 let c_end = std::cmp::min(c_start + BLOCK_SIZE, dense_cols);
514
515 for r in r_start..r_end {
516 for c in c_start..c_end {
517 result[(r, c)] += local_result[(r, c)];
518 }
519 }
520 }
521 }
522 }
523 }
524 }
525}
526
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// A small hand-written matrix: the view must report masked dimensions
    /// and the SVD must run on it.
    #[test]
    fn test_masked_matrix() {
        let entries: [(usize, usize, f64); 8] = [
            (0, 0, 1.0),
            (0, 2, 2.0),
            (0, 4, 3.0),
            (1, 1, 4.0),
            (1, 3, 5.0),
            (2, 0, 6.0),
            (2, 2, 7.0),
            (2, 4, 8.0),
        ];

        let mut coo = CooMatrix::<f64>::new(3, 5);
        for &(row, col, val) in &entries {
            coo.push(row, col, val);
        }
        let csr = CsrMatrix::from(&coo);

        let kept: Vec<usize> = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &kept);

        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6);

        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

    /// The masked view must behave like a matrix that physically contains
    /// only the selected columns.
    #[test]
    fn test_masked_vs_physical_subset() {
        let mut rng = StdRng::seed_from_u64(42);

        let nrows = 14;
        let ncols = 10;
        let nnz = 40;

        // Random sparse matrix; draws happen in row, column, value order.
        let mut coo = CooMatrix::<f64>::new(nrows, ncols);
        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);
            coo.push(row, col, val);
        }
        let csr = CsrMatrix::from(&coo);

        let selected_columns: Vec<usize> = vec![1, 3, 5, 7];
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Reference: copy the selected columns into a standalone matrix,
        // renumbering them by their position in `selected_columns`.
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());
        for (row, col, val) in coo.triplet_iter() {
            let new_col = selected_columns.iter().position(|&kept| kept == col);
            if let Some(new_col) = new_col {
                physical_subset.push(row, new_col, *val);
            }
        }
        let physical_csr = CsrMatrix::from(&physical_subset);

        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        assert_eq!(svd_masked.d, svd_physical.d);

        // Compare singular values in descending order.
        let epsilon = 1e-10;

        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }
    }
}