single_svdlib/lanczos/masked.rs

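//! A read-only, column-masked view over a CSR matrix that implements the
//! [`SMat`] operator interface, so the Lanczos SVD can be computed on a
//! subset of columns without materializing a physical copy of the matrix.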
use crate::{determine_chunk_size, SMat};
use nalgebra_sparse::na::{DMatrix, DVector};
use nalgebra_sparse::CsrMatrix;
use num_traits::Float;
use rayon::iter::IndexedParallelIterator;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelIterator, ParallelSliceMut};
use std::fmt::Debug;
use std::ops::AddAssign;

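/// A zero-copy, column-masked view of a [`CsrMatrix`].
///
/// Masked (compacted) column indices run from `0..ncols()`; `masked_to_original`
/// maps them back to columns of the underlying matrix, and `original_to_masked`
/// is the inverse map, holding `None` for excluded columns.
///
/// A minimal usage sketch (mirroring the unit tests below):
///
/// ```ignore
/// // Build a CSR matrix elsewhere, then view only columns 0, 2, and 4 of it
/// // without copying any values:
/// let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);
/// assert_eq!(masked.ncols(), 3);
/// let svd = crate::lanczos::svd(&masked);
/// ```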
pub struct MaskedCSRMatrix<'a, T: Float> {
    matrix: &'a CsrMatrix<T>,
    column_mask: Vec<bool>,
    masked_to_original: Vec<usize>,
    original_to_masked: Vec<Option<usize>>,
}

impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
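    /// Builds a masked view from an explicit boolean mask with one entry per
    /// column of `matrix` (`true` = keep the column).
    ///
    /// # Panics
    /// Panics if `column_mask.len() != matrix.ncols()`.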
    pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
        assert_eq!(
            column_mask.len(),
            matrix.ncols(),
            "Column mask must have the same length as the number of columns in the matrix"
        );

        let mut masked_to_original = Vec::new();
        let mut original_to_masked = vec![None; column_mask.len()];
        let mut masked_index = 0;

        for (i, &is_included) in column_mask.iter().enumerate() {
            if is_included {
                masked_to_original.push(i);
                original_to_masked[i] = Some(masked_index);
                masked_index += 1;
            }
        }

        Self {
            matrix,
            column_mask,
            masked_to_original,
            original_to_masked,
        }
    }

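    /// Convenience constructor: builds the view from the list of column
    /// indices to keep (e.g. `&[0, 2, 4]`).
    ///
    /// # Panics
    /// Panics if any index is out of bounds.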
    pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
        let mut mask = vec![false; matrix.ncols()];
        for &col in columns {
            assert!(col < matrix.ncols(), "Column index out of bounds");
            mask[col] = true;
        }
        Self::new(matrix, mask)
    }

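    /// Returns `true` when the mask keeps every column of the underlying
    /// matrix, in which case operations can delegate to it directly.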
    pub fn uses_all_columns(&self) -> bool {
        self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
    }

    pub fn ensure_identical_results_mode(&self) -> bool {
        // For very small matrices where precision is critical
        let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
        is_small_matrix && self.uses_all_columns()
    }
}

impl<'a, T> SMat<T> for MaskedCSRMatrix<'a, T>
where
    T: Float
        + AddAssign
        + Sync
        + Send
        + std::ops::MulAssign
        + Debug
        + 'static
        + std::iter::Sum
        + std::ops::SubAssign
        + num_traits::FromPrimitive,
{
    fn nrows(&self) -> usize {
        self.matrix.nrows()
    }

    fn ncols(&self) -> usize {
        self.masked_to_original.len()
    }

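    /// Counts the stored entries that fall in unmasked columns by scanning
    /// the underlying matrix, so this costs O(nnz) of the original matrix.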
    fn nnz(&self) -> usize {
        let (major_offsets, minor_indices, _) = self.matrix.csr_data();
        let mut count = 0;

        for i in 0..self.matrix.nrows() {
            for j in major_offsets[i]..major_offsets[i + 1] {
                let col = minor_indices[j];
                if self.column_mask[col] {
                    count += 1;
                }
            }
        }
        count
    }

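    /// Sparse operator application for the Lanczos driver: computes
    /// `y = A * x`, or `y = A^T * x` when `transposed` is set, where `A` is
    /// the masked view (vectors over masked columns use compacted indices).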
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed {
            self.ncols()
        } else {
            self.nrows()
        };
        let ncols = if transposed {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        // Fast path: when no columns are masked out, the view is identical to
        // the underlying matrix, so delegate directly. (Delegating is only
        // correct when all columns are used; with an active mask the view has
        // fewer columns than the underlying matrix, so `x` and `y` would be
        // mis-sized for it.)
        if self.uses_all_columns() {
            self.matrix.svd_opa(x, y, transposed);
            return;
        }

        y.fill(T::zero());

        if !transposed {
            // A * x calculation
            let valid_indices: Vec<Option<usize>> = (0..self.matrix.ncols())
                .map(|col| self.original_to_masked[col])
                .collect();

            // Parallelization parameters
            let rows = self.matrix.nrows();
            let chunk_size = std::cmp::max(16, rows / (rayon::current_num_threads() * 2));

            // Process in parallel chunks
            y.par_chunks_mut(chunk_size)
                .enumerate()
                .for_each(|(chunk_idx, y_chunk)| {
                    let start_row = chunk_idx * chunk_size;
                    let end_row = (start_row + y_chunk.len()).min(rows);

                    for i in start_row..end_row {
                        let row_idx = i - start_row;
                        let mut sum = T::zero();

                        let row_start = major_offsets[i];
                        let row_end = major_offsets[i + 1];

                        // Unroll the loop by 4 for better instruction-level parallelism
                        let mut j = row_start;
                        while j + 4 <= row_end {
                            for offset in 0..4 {
                                let idx = j + offset;
                                let col = minor_indices[idx];
                                if let Some(masked_col) = valid_indices[col] {
                                    sum += values[idx] * x[masked_col];
                                }
                            }
                            j += 4;
                        }

                        // Handle remaining elements
                        while j < row_end {
                            let col = minor_indices[j];
                            if let Some(masked_col) = valid_indices[col] {
                                sum += values[j] * x[masked_col];
                            }
                            j += 1;
                        }

                        y_chunk[row_idx] = sum;
                    }
                });
        } else {
            // A^T * x calculation
            let nrows = self.matrix.nrows();
            let chunk_size = determine_chunk_size(nrows);

            // Create thread-local partial results and combine at the end
            let results: Vec<Vec<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);
                    let mut local_y = vec![T::zero(); y.len()];

                    // Process a chunk of rows
                    for i in start..end {
                        let row_val = x[i];
                        if row_val.is_zero() {
                            continue; // Skip zero values for performance
                        }

                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                local_y[masked_col] += values[j] * row_val;
                            }
                        }
                    }
                    local_y
                })
                .collect();

            // Combine results efficiently
            for local_y in results {
                // Only update non-zero elements to reduce memory traffic
                for (idx, &val) in local_y.iter().enumerate() {
                    if !val.is_zero() {
                        y[idx] += val;
                    }
                }
            }
        }
    }

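    /// Returns the mean of each masked column (length `ncols()`), as used by
    /// the implicitly centered code paths.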
    fn compute_column_means(&self) -> Vec<T> {
        let rows = self.nrows();
        let masked_cols = self.ncols();
        let row_count_recip = T::one() / T::from(rows).unwrap();

        let mut col_sums = vec![T::zero(); masked_cols];
        let (row_offsets, col_indices, values) = self.matrix.csr_data();

        for i in 0..rows {
            for j in row_offsets[i]..row_offsets[i + 1] {
                let original_col = col_indices[j];
                if let Some(masked_col) = self.original_to_masked[original_col] {
                    col_sums[masked_col] += values[j];
                }
            }
        }

        // Convert to means
        for j in 0..masked_cols {
            col_sums[j] *= row_count_recip;
        }

        col_sums
    }

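    /// Computes `result = A * dense`, or `result = A^T * dense` when
    /// `transpose_self` is set, where `A` is the masked view.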
    fn multiply_with_dense(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
    ) {
        let m_rows = if transpose_self {
            self.ncols()
        } else {
            self.nrows()
        };
        let m_cols = if transpose_self {
            self.nrows()
        } else {
            self.ncols()
        };

        assert_eq!(
            dense.nrows(),
            m_cols,
            "Dense matrix has incompatible row count"
        );
        assert_eq!(
            result.nrows(),
            m_rows,
            "Result matrix has incompatible row count"
        );
        assert_eq!(
            result.ncols(),
            dense.ncols(),
            "Result matrix has incompatible column count"
        );

        // Note: for very small all-column matrices (see
        // `ensure_identical_results_mode`), the default `SMat::multiply_matrix`
        // implementation could be dispatched here for bit-identical results;
        // the optimized paths below are currently used unconditionally.

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        if !transpose_self {
            let rows = self.matrix.nrows();
            let dense_cols = dense.ncols();

            let partial_results: Vec<(usize, DMatrix<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut local_result = DMatrix::<T>::zeros(1, dense_cols);

                    for j in major_offsets[row]..major_offsets[row + 1] {
                        let col = minor_indices[j];
                        if let Some(masked_col) = self.original_to_masked[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                local_result[(0, c)] += val * dense[(masked_col, c)];
                            }
                        }
                    }

                    (row, local_result)
                })
                .collect();

            for (row, local_result) in partial_results {
                for c in 0..dense_cols {
                    result[(row, c)] = local_result[(0, c)];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();
            let dense_cols = dense.ncols();

            let chunk_size = determine_chunk_size(nrows);

            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = (start + chunk_size).min(nrows);

                    let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);

                    for i in start..end {
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                let val = values[j];

                                for c in 0..dense_cols {
                                    local_result[(masked_col, c)] += val * dense[(i, c)];
                                }
                            }
                        }
                    }

                    local_result
                })
                .collect();

            for local_result in partial_results {
                for r in 0..ncols {
                    for c in 0..dense_cols {
                        let val = local_result[(r, c)];
                        if !val.is_zero() {
                            result[(r, c)] += val;
                        }
                    }
                }
            }
        }
    }

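    /// Like `multiply_with_dense`, but for the implicitly centered matrix
    /// `A - 1 m^T` (each masked column has its mean subtracted), computed
    /// without materializing the centered matrix via
    /// `(A - 1 m^T) D = A D - 1 (m^T D)` and
    /// `(A - 1 m^T)^T D = A^T D - m (1^T D)`.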
    fn multiply_with_dense_centered(
        &self,
        dense: &DMatrix<T>,
        result: &mut DMatrix<T>,
        transpose_self: bool,
        means: &DVector<T>,
    ) {
        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        let dense_cols = dense.ncols();
        let dense_rows = dense.nrows();

        // Pre-compute all column sums of `dense` once; these form the
        // `1^T D` term used by the transposed branch below.
        let col_sums: Vec<T> = (0..dense_cols)
            .into_par_iter()
            .map(|c| (0..dense_rows).map(|i| dense[(i, c)]).sum())
            .collect();

        if !transpose_self {
            let rows = self.matrix.nrows();

            // Centering term for `(A - 1 m^T) D`: for each dense column c the
            // adjustment is `(m^T D)[c]`, the dot product of the masked-column
            // means with column c of `dense`. (`means` is indexed by masked
            // column, and `dense` has `self.ncols()` rows here.)
            let mean_adjustments: Vec<T> = (0..dense_cols)
                .into_par_iter()
                .map(|c| {
                    means
                        .iter()
                        .enumerate()
                        .map(|(masked_idx, &mean_val)| mean_val * dense[(masked_idx, c)])
                        .sum()
                })
                .collect();

            let row_updates: Vec<(usize, Vec<T>)> = (0..rows)
                .into_par_iter()
                .map(|row| {
                    let mut row_result = vec![T::zero(); dense_cols];

                    for j in major_offsets[row]..major_offsets[row + 1] {
                        let col = minor_indices[j];
                        if let Some(masked_col) = self.original_to_masked[col] {
                            let val = values[j];

                            for c in 0..dense_cols {
                                row_result[c] += val * dense[(masked_col, c)];
                            }
                        }
                    }

                    for c in 0..dense_cols {
                        row_result[c] -= mean_adjustments[c];
                    }

                    (row, row_result)
                })
                .collect();

            for (row, row_values) in row_updates {
                for c in 0..dense_cols {
                    result[(row, c)] = row_values[c];
                }
            }
        } else {
            let nrows = self.matrix.nrows();
            let ncols = self.ncols();

            // Clear the result matrix first
            result.fill(T::zero());

            // Choose optimal chunk size
            let chunk_size = determine_chunk_size(nrows);

            // Compute partial results in parallel
            let partial_results: Vec<DMatrix<T>> = (0..nrows.div_ceil(chunk_size))
                .into_par_iter()
                .map(|chunk_idx| {
                    let start = chunk_idx * chunk_size;
                    let end = std::cmp::min(start + chunk_size, nrows);

                    let mut local_result = DMatrix::<T>::zeros(ncols, dense_cols);

                    for i in start..end {
                        for j in major_offsets[i]..major_offsets[i + 1] {
                            let col = minor_indices[j];
                            if let Some(masked_col) = self.original_to_masked[col] {
                                let sparse_val = values[j];

                                for c in 0..dense_cols {
                                    local_result[(masked_col, c)] += sparse_val * dense[(i, c)];
                                }
                            }
                        }
                    }

                    // Apply this chunk's share of the `m (1^T D)` centering term.
                    // The fractions (end - start) / dense_rows sum to 1 across
                    // all chunks, so the full adjustment is applied exactly once
                    // after the partial results are combined.
                    let chunk_fraction =
                        T::from_f64((end - start) as f64 / dense_rows as f64).unwrap();

                    for masked_col in 0..ncols {
                        if masked_col < means.len() {
                            let mean = means[masked_col];
                            for c in 0..dense_cols {
                                local_result[(masked_col, c)] -=
                                    mean * col_sums[c] * chunk_fraction;
                            }
                        }
                    }

                    local_result
                })
                .collect();

            // Accumulate the partial results in cache-friendly blocks
            for local_result in partial_results {
                const BLOCK_SIZE: usize = 32;

                for r_block in 0..ncols.div_ceil(BLOCK_SIZE) {
                    let r_start = r_block * BLOCK_SIZE;
                    let r_end = std::cmp::min(r_start + BLOCK_SIZE, ncols);

                    for c_block in 0..dense_cols.div_ceil(BLOCK_SIZE) {
                        let c_start = c_block * BLOCK_SIZE;
                        let c_end = std::cmp::min(c_start + BLOCK_SIZE, dense_cols);

                        for r in r_start..r_end {
                            for c in c_start..c_end {
                                result[(r, c)] += local_result[(r, c)];
                            }
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    #[test]
    fn test_masked_matrix() {
        // Create a test matrix
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);

        let csr = CsrMatrix::from(&coo);

        // Create a masked matrix with columns 0, 2, 4
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        // Check dimensions
        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6); // Only entries in the selected columns

        // Test SVD on the masked matrix
        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        // Use a fixed seed for reproducible tests
        let mut rng = StdRng::seed_from_u64(42);

        // Generate a random 14x10 matrix
        let nrows = 14;
        let ncols = 10;
        let nnz = 40; // Number of entries to push

        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        // Fill with random non-zero values
        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);

            // Note: CooMatrix allows duplicate entries; they are summed when
            // converting to CSR, so the actual nnz may be lower than `nnz`.
            coo.push(row, col, val);
        }

        // Convert to CSR, which is what our masked implementation uses
        let csr = CsrMatrix::from(&coo);

        // Select a subset of columns (e.g., columns 1, 3, 5, 7)
        let selected_columns = vec![1, 3, 5, 7];

        // Create the masked matrix view
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Create a physical copy with just those columns
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());

        // Map original column indices to new column indices
        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        // Copy the values for the selected columns
        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        // Convert to CSR for SVD
        let physical_csr = CsrMatrix::from(&physical_subset);

        // Compare dimensions and nnz
        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        // Perform SVD on both
        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        // Compare SVD results - they should be very close but not exactly the
        // same due to potential differences in numerical computation

        // Check dimension (rank)
        assert_eq!(svd_masked.d, svd_physical.d);

        // Basic tolerance for floating point comparisons
        let epsilon = 1e-10;

        // Check singular values (may be in different order, so we sort them)
        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap()); // Sort in descending order
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }

        // Note: Comparing singular vectors is more complex due to potential
        // sign flips and different ordering, so we skip that level of
        // detailed comparison.
    }
}