1use crate::utils::determine_chunk_size;
2use crate::{SMat, SvdFloat};
3use nalgebra_sparse::CsrMatrix;
4use num_traits::Float;
5use rayon::iter::IndexedParallelIterator;
6use rayon::iter::ParallelIterator;
7use rayon::prelude::{IntoParallelIterator, ParallelBridge, ParallelSliceMut};
8use std::ops::AddAssign;
9
10pub struct MaskedCSRMatrix<'a, T: Float> {
11 matrix: &'a CsrMatrix<T>,
12 column_mask: Vec<bool>,
13 masked_to_original: Vec<usize>,
14 original_to_masked: Vec<Option<usize>>,
15}
16
17impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
18 pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
19 assert_eq!(
20 column_mask.len(),
21 matrix.ncols(),
22 "Column mask must have the same length as the number of columns in the matrix"
23 );
24
25 let mut masked_to_original = Vec::new();
26 let mut original_to_masked = vec![None; column_mask.len()];
27 let mut masked_index = 0;
28
29 for (i, &is_included) in column_mask.iter().enumerate() {
30 if is_included {
31 masked_to_original.push(i);
32 original_to_masked[i] = Some(masked_index);
33 masked_index += 1;
34 }
35 }
36
37 Self {
38 matrix,
39 column_mask,
40 masked_to_original,
41 original_to_masked,
42 }
43 }
44
45 pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
46 let mut mask = vec![false; matrix.ncols()];
47 for &col in columns {
48 assert!(col < matrix.ncols(), "Column index out of bounds");
49 mask[col] = true;
50 }
51 Self::new(matrix, mask)
52 }
53
54 pub fn uses_all_columns(&self) -> bool {
55 self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
56 }
57
58 pub fn ensure_identical_results_mode(&self) -> bool {
59 let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
61 is_small_matrix && self.uses_all_columns()
62 }
63}
64
65impl<'a, T: Float + AddAssign + Sync + Send + std::ops::MulAssign> SMat<T> for MaskedCSRMatrix<'a, T> {
66 fn nrows(&self) -> usize {
67 self.matrix.nrows()
68 }
69
70 fn ncols(&self) -> usize {
71 self.masked_to_original.len()
72 }
73
74 fn nnz(&self) -> usize {
75 let (major_offsets, minor_indices, _) = self.matrix.csr_data();
76 let mut count = 0;
77
78 for i in 0..self.matrix.nrows() {
79 for j in major_offsets[i]..major_offsets[i + 1] {
80 let col = minor_indices[j];
81 if self.column_mask[col] {
82 count += 1;
83 }
84 }
85 }
86 count
87 }
88
89 fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
90 let nrows = if transposed {
91 self.ncols()
92 } else {
93 self.nrows()
94 };
95 let ncols = if transposed {
96 self.nrows()
97 } else {
98 self.ncols()
99 };
100
101 assert_eq!(
102 x.len(),
103 ncols,
104 "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
105 x.len(),
106 ncols
107 );
108 assert_eq!(
109 y.len(),
110 nrows,
111 "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
112 y.len(),
113 nrows
114 );
115
116 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
117
118 y.fill(T::zero());
119
120 if !transposed {
121 let row_count = self.matrix.nrows();
123 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
124
125 let chunk_size = std::cmp::max(16, row_count / (rayon::current_num_threads() * 2));
126
127 let mut valid_indices = Vec::with_capacity(self.matrix.ncols());
128 for col in 0..self.matrix.ncols() {
129 valid_indices.push(self.original_to_masked[col]);
130 }
131
132 y.par_chunks_mut(chunk_size)
133 .enumerate()
134 .for_each(|(chunk_idx, y_chunk)| {
135 let start_row = chunk_idx * chunk_size;
136 let end_row = (start_row + y_chunk.len()).min(row_count);
137
138 for i in start_row..end_row {
139 let row_idx = i - start_row;
140 let mut sum = T::zero();
141
142 let row_start = major_offsets[i];
143 let row_end = major_offsets[i + 1];
144
145 let mut j = row_start;
146
147 while j + 4 <= row_end {
148 for offset in 0..4 {
149 let idx = j + offset;
150 let col = minor_indices[idx];
151 if let Some(masked_col) = valid_indices[col] {
152 sum += values[idx] * x[masked_col];
153 }
154 }
155 j += 4;
156 }
157
158 while j < row_end {
159 let col = minor_indices[j];
160 if let Some(masked_col) = valid_indices[col] {
161 sum += values[j] * x[masked_col];
162 }
163 j += 1;
164 }
165
166 y_chunk[row_idx] = sum;
167 }
168 });
169 } else {
170 let nrows = self.matrix.nrows();
172 let chunk_size = determine_chunk_size(nrows);
173
174 let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size))
176 .into_par_iter()
177 .map(|chunk_idx| {
178 let start = chunk_idx * chunk_size;
179 let end = (start + chunk_size).min(nrows);
180
181 let mut local_y = vec![T::zero(); y.len()];
182 for i in start..end {
183 let row_val = x[i];
184 for j in major_offsets[i]..major_offsets[i + 1] {
185 let col = minor_indices[j];
186 if let Some(masked_col) = self.original_to_masked[col] {
187 local_y[masked_col] += values[j] * row_val;
188 }
189 }
190 }
191 local_y
192 })
193 .collect();
194
195 for local_y in results {
197 for (idx, val) in local_y.iter().enumerate() {
198 if !val.is_zero() {
199 y[idx] += *val;
200 }
201 }
202 }
203 }
204 }
205
206 fn compute_column_means(&self) -> Vec<T> {
207 let rows = self.nrows();
208 let masked_cols = self.ncols();
209 let row_count_recip = T::one() / T::from(rows).unwrap();
210
211 let mut col_sums = vec![T::zero(); masked_cols];
212 let (row_offsets, col_indices, values) = self.matrix.csr_data();
213
214 for i in 0..rows {
215 for j in row_offsets[i]..row_offsets[i + 1] {
216 let original_col = col_indices[j];
217 if let Some(masked_col) = self.original_to_masked[original_col] {
218 col_sums[masked_col] += values[j];
219 }
220 }
221 }
222
223 for j in 0..masked_cols {
225 col_sums[j] *= row_count_recip;
226 }
227
228 col_sums
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SMat;
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// 3x5 fixture: columns 0, 2 and 4 are populated in rows 0 and 2, while
    /// columns 1 and 3 are populated only in row 1.
    fn small_fixture() -> CsrMatrix<f64> {
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);
        CsrMatrix::from(&coo)
    }

    #[test]
    fn test_masked_matrix() {
        let csr = small_fixture();
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6);

        let svd_result = crate::lanczos::svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_svd_opa_matches_dense_reference() {
        let csr = small_fixture();
        let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);
        // Dense form of the masked view (original columns 0, 2, 4 in that order).
        let dense = [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [6.0, 7.0, 8.0]];

        // y = A x; y starts with stale values that must be fully overwritten.
        let x = [1.0, -2.0, 0.5];
        let mut y = vec![99.0; 3];
        masked.svd_opa(&x, &mut y, false);
        for (got, row) in y.iter().zip(dense.iter()) {
            let expected: f64 = row.iter().zip(x.iter()).map(|(a, b)| a * b).sum();
            assert!((got - expected).abs() < 1e-12, "A x: {} vs {}", got, expected);
        }

        // y = A^T x; again starting from stale values.
        let xt = [2.0, 5.0, -1.0];
        let mut yt = vec![99.0; 3];
        masked.svd_opa(&xt, &mut yt, true);
        for (c, got) in yt.iter().enumerate() {
            let expected: f64 = dense.iter().zip(xt.iter()).map(|(row, b)| row[c] * b).sum();
            assert!((got - expected).abs() < 1e-12, "A^T x: {} vs {}", got, expected);
        }
    }

    #[test]
    fn test_column_means_use_only_kept_columns() {
        let csr = small_fixture();
        let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);

        let means = masked.compute_column_means();
        let expected = [7.0 / 3.0, 9.0 / 3.0, 11.0 / 3.0];

        assert_eq!(means.len(), expected.len());
        for (m, e) in means.iter().zip(expected.iter()) {
            assert!((m - e).abs() < 1e-12, "column mean {} vs {}", m, e);
        }
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        let mut rng = StdRng::seed_from_u64(42);

        let nrows = 14;
        let ncols = 10;
        let nnz = 40;
        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);
            coo.push(row, col, val);
        }

        let csr = CsrMatrix::from(&coo);

        let selected_columns = vec![1, 3, 5, 7];
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Build the same column subset physically to compare against.
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());
        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        let physical_csr = CsrMatrix::from(&physical_subset);

        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        let svd_masked = crate::lanczos::svd(&masked_matrix).unwrap();
        let svd_physical = crate::lanczos::svd(&physical_csr).unwrap();

        assert_eq!(svd_masked.d, svd_physical.d);

        let epsilon = 1e-10;

        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        // `zip` would silently truncate, so pin the lengths first.
        assert_eq!(masked_s.len(), physical_s.len());
        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }
    }
}