single_svdlib/lanczos/masked.rs

1use crate::utils::determine_chunk_size;
2use crate::{SMat, SvdFloat};
3use nalgebra_sparse::CsrMatrix;
4use num_traits::Float;
5use rayon::iter::IndexedParallelIterator;
6use rayon::iter::ParallelIterator;
7use rayon::prelude::{IntoParallelIterator, ParallelBridge, ParallelSliceMut};
8use std::ops::AddAssign;
9
10pub struct MaskedCSRMatrix<'a, T: Float> {
11    matrix: &'a CsrMatrix<T>,
12    column_mask: Vec<bool>,
13    masked_to_original: Vec<usize>,
14    original_to_masked: Vec<Option<usize>>,
15}
16
17impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
18    pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
19        assert_eq!(
20            column_mask.len(),
21            matrix.ncols(),
22            "Column mask must have the same length as the number of columns in the matrix"
23        );
24
25        let mut masked_to_original = Vec::new();
26        let mut original_to_masked = vec![None; column_mask.len()];
27        let mut masked_index = 0;
28
29        for (i, &is_included) in column_mask.iter().enumerate() {
30            if is_included {
31                masked_to_original.push(i);
32                original_to_masked[i] = Some(masked_index);
33                masked_index += 1;
34            }
35        }
36
37        Self {
38            matrix,
39            column_mask,
40            masked_to_original,
41            original_to_masked,
42        }
43    }
44
45    pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
46        let mut mask = vec![false; matrix.ncols()];
47        for &col in columns {
48            assert!(col < matrix.ncols(), "Column index out of bounds");
49            mask[col] = true;
50        }
51        Self::new(matrix, mask)
52    }
53
54    pub fn uses_all_columns(&self) -> bool {
55        self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
56    }
57
58    pub fn ensure_identical_results_mode(&self) -> bool {
59        // For very small matrices where precision is critical
60        let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
61        is_small_matrix && self.uses_all_columns()
62    }
63}
64
65impl<'a, T: Float + AddAssign + Sync + Send + std::ops::MulAssign> SMat<T> for MaskedCSRMatrix<'a, T> {
66    fn nrows(&self) -> usize {
67        self.matrix.nrows()
68    }
69
70    fn ncols(&self) -> usize {
71        self.masked_to_original.len()
72    }
73
74    fn nnz(&self) -> usize {
75        let (major_offsets, minor_indices, _) = self.matrix.csr_data();
76        let mut count = 0;
77
78        for i in 0..self.matrix.nrows() {
79            for j in major_offsets[i]..major_offsets[i + 1] {
80                let col = minor_indices[j];
81                if self.column_mask[col] {
82                    count += 1;
83                }
84            }
85        }
86        count
87    }
88
89    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
90        let nrows = if transposed {
91            self.ncols()
92        } else {
93            self.nrows()
94        };
95        let ncols = if transposed {
96            self.nrows()
97        } else {
98            self.ncols()
99        };
100
101        assert_eq!(
102            x.len(),
103            ncols,
104            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
105            x.len(),
106            ncols
107        );
108        assert_eq!(
109            y.len(),
110            nrows,
111            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
112            y.len(),
113            nrows
114        );
115
116        let (major_offsets, minor_indices, values) = self.matrix.csr_data();
117
118        y.fill(T::zero());
119
120        if !transposed {
121            // A * x calculation
122            let row_count = self.matrix.nrows();
123            let (major_offsets, minor_indices, values) = self.matrix.csr_data();
124
125            let chunk_size = std::cmp::max(16, row_count / (rayon::current_num_threads() * 2));
126
127            let mut valid_indices = Vec::with_capacity(self.matrix.ncols());
128            for col in 0..self.matrix.ncols() {
129                valid_indices.push(self.original_to_masked[col]);
130            }
131
132            y.par_chunks_mut(chunk_size)
133                .enumerate()
134                .for_each(|(chunk_idx, y_chunk)| {
135                    let start_row = chunk_idx * chunk_size;
136                    let end_row = (start_row + y_chunk.len()).min(row_count);
137
138                    for i in start_row..end_row {
139                        let row_idx = i - start_row;
140                        let mut sum = T::zero();
141
142                        let row_start = major_offsets[i];
143                        let row_end = major_offsets[i + 1];
144
145                        let mut j = row_start;
146
147                        while j + 4 <= row_end {
148                            for offset in 0..4 {
149                                let idx = j + offset;
150                                let col = minor_indices[idx];
151                                if let Some(masked_col) = valid_indices[col] {
152                                    sum += values[idx] * x[masked_col];
153                                }
154                            }
155                            j += 4;
156                        }
157
158                        while j < row_end {
159                            let col = minor_indices[j];
160                            if let Some(masked_col) = valid_indices[col] {
161                                sum += values[j] * x[masked_col];
162                            }
163                            j += 1;
164                        }
165
166                        y_chunk[row_idx] = sum;
167                    }
168                });
169        } else {
170            // A^T * x calculation
171            let nrows = self.matrix.nrows();
172            let chunk_size = determine_chunk_size(nrows);
173
174            // Process in parallel chunks
175            let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
176                .into_par_iter()
177                .map(|chunk_idx| {
178                    let start = chunk_idx * chunk_size;
179                    let end = (start + chunk_size).min(nrows);
180
181                    let mut local_y = vec![T::zero(); y.len()];
182                    for i in start..end {
183                        let row_val = x[i];
184                        for j in major_offsets[i]..major_offsets[i + 1] {
185                            let col = minor_indices[j];
186                            if let Some(masked_col) = self.original_to_masked[col] {
187                                local_y[masked_col] += values[j] * row_val;
188                            }
189                        }
190                    }
191                    local_y
192                })
193                .collect();
194
195            // Combine results
196            for local_y in results {
197                for (idx, val) in local_y.iter().enumerate() {
198                    if !val.is_zero() {
199                        y[idx] += *val;
200                    }
201                }
202            }
203        }
204    }
205
206    fn compute_column_means(&self) -> Vec<T> {
207        let rows = self.nrows();
208        let masked_cols = self.ncols();
209        let row_count_recip = T::one() / T::from(rows).unwrap();
210
211        let mut col_sums = vec![T::zero(); masked_cols];
212        let (row_offsets, col_indices, values) = self.matrix.csr_data();
213        
214        for i in 0..rows {
215            for j in row_offsets[i]..row_offsets[i + 1] {
216                let original_col = col_indices[j];
217                if let Some(masked_col) = self.original_to_masked[original_col] {
218                    col_sums[masked_col] += values[j];
219                }
220            }
221        }
222
223        // Convert to means
224        for j in 0..masked_cols {
225            col_sums[j] *= row_count_recip;
226        }
227
228        col_sums
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    // Absolute tolerance for comparing against hand-computed values.
    const TOL: f64 = 1e-12;

    // Builds the 3x5 fixture
    //   [1 0 2 0 3]
    //   [0 4 0 5 0]
    //   [6 0 7 0 8]
    fn fixture_3x5() -> CsrMatrix<f64> {
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);
        CsrMatrix::from(&coo)
    }

    // Asserts element-wise closeness, naming the failing index in the message.
    fn assert_close(label: &str, got: &[f64], want: &[f64], tol: f64) {
        assert_eq!(got.len(), want.len(), "{}: length mismatch", label);
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert!((g - w).abs() < tol, "{}[{}]: {} vs {}", label, i, g, w);
        }
    }

    #[test]
    fn test_masked_matrix() {
        let csr = fixture_3x5();

        // Create a masked matrix with columns 0, 2, 4
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        // Check dimensions
        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6); // Only entries in the selected columns

        // Test SVD on the masked matrix
        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_masked_products_and_means() {
        let csr = fixture_3x5();
        // Visible columns 0, 2, 4 give the masked view [[1,2,3],[0,0,0],[6,7,8]].
        let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);

        // A * x. `y` starts as garbage to check that it is fully overwritten.
        let mut y = vec![99.0; 3];
        masked.svd_opa(&[1.0, 2.0, 3.0], &mut y, false);
        assert_close("A*x", &y, &[14.0, 0.0, 44.0], TOL);

        // A^T * x
        let mut yt = vec![99.0; 3];
        masked.svd_opa(&[1.0, 2.0, 3.0], &mut yt, true);
        assert_close("A^T*x", &yt, &[19.0, 23.0, 27.0], TOL);

        // Column means over all 3 rows (the all-zero row counts).
        assert_close(
            "column means",
            &masked.compute_column_means(),
            &[7.0 / 3.0, 3.0, 11.0 / 3.0],
            TOL,
        );

        // Column-coverage helpers.
        let all = MaskedCSRMatrix::with_columns(&csr, &[0, 1, 2, 3, 4]);
        assert!(all.uses_all_columns());
        assert!(all.ensure_identical_results_mode());
        assert!(!masked.uses_all_columns());
        assert!(!masked.ensure_identical_results_mode());

        // A mask with no columns yields an all-zero product.
        let empty = MaskedCSRMatrix::with_columns(&csr, &[]);
        assert_eq!(empty.ncols(), 0);
        assert_eq!(empty.nnz(), 0);
        let mut y0 = vec![99.0; 3];
        empty.svd_opa(&[], &mut y0, false);
        assert_eq!(y0, vec![0.0; 3]);
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        // Create a fixed seed for reproducible tests
        let mut rng = StdRng::seed_from_u64(42);

        // Generate a random matrix (14x10)
        let nrows = 14;
        let ncols = 10;
        let nnz = 40; // Number of non-zero elements

        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        // Fill with random non-zero values
        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);

            // Note: CooMatrix keeps duplicate entries; converting to CSR sums them.
            coo.push(row, col, val);
        }

        // Convert to CSR which is what our masked implementation uses
        let csr = CsrMatrix::from(&coo);

        // Select a subset of columns (e.g., columns 1, 3, 5, 7)
        let selected_columns = vec![1, 3, 5, 7];

        // Create the masked matrix view
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Create a physical copy with just those columns
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());

        // Map original column indices to new column indices
        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        // Copy the values for the selected columns (duplicates are summed by the
        // CSR conversion below, exactly as for `csr`)
        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        // Convert to CSR for SVD
        let physical_csr = CsrMatrix::from(&physical_subset);

        // Compare dimensions and nnz
        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        // The masked view must yield the same products and column means as the
        // physical copy (up to floating-point rounding).
        let x_cols: Vec<f64> = (0..selected_columns.len()).map(|i| i as f64 + 1.0).collect();
        let x_rows: Vec<f64> = (0..nrows).map(|i| 0.5 * i as f64 - 2.0).collect();

        let mut y_masked = vec![0.0; nrows];
        let mut y_physical = vec![0.0; nrows];
        masked_matrix.svd_opa(&x_cols, &mut y_masked, false);
        physical_csr.svd_opa(&x_cols, &mut y_physical, false);
        assert_close("A*x", &y_masked, &y_physical, 1e-10);

        let mut yt_masked = vec![0.0; selected_columns.len()];
        let mut yt_physical = vec![0.0; selected_columns.len()];
        masked_matrix.svd_opa(&x_rows, &mut yt_masked, true);
        physical_csr.svd_opa(&x_rows, &mut yt_physical, true);
        assert_close("A^T*x", &yt_masked, &yt_physical, 1e-10);

        assert_close(
            "column means",
            &masked_matrix.compute_column_means(),
            &physical_csr.compute_column_means(),
            1e-10,
        );

        // Perform SVD on both
        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        // Compare SVD results - they should be very close but not exactly the same
        // due to potential differences in numerical computation

        // Check dimension (rank)
        assert_eq!(svd_masked.d, svd_physical.d);

        // Basic tolerance for floating point comparisons
        let epsilon = 1e-10;

        // Check singular values (may be in different order, so we sort them)
        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap()); // Sort in descending order
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }

        // Note: Comparing singular vectors is more complex due to potential sign flips
        // and different ordering, so we'll skip that level of detailed comparison
    }
}