use crate::utils::determine_chunk_size;
use crate::SMat;
use nalgebra_sparse::CsrMatrix;
use num_traits::Float;
use rayon::iter::ParallelIterator;
use rayon::prelude::{IntoParallelIterator, ParallelBridge};
use std::ops::AddAssign;

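/// A zero-copy, column-masked view over a `CsrMatrix`.
///
/// All rows of the underlying matrix are exposed, but only the columns
/// enabled in `column_mask` are visible, re-indexed densely as `0..ncols()`.
/// The two lookup tables give O(1) translation between original and masked
/// column indices in both directions.
///
/// A minimal usage sketch (not compiled as a doctest; `csr` is assumed to be
/// a `CsrMatrix<f64>` with at least five columns):
///
/// ```ignore
/// let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);
/// assert_eq!(masked.ncols(), 3); // three visible columns
/// ```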
pub struct MaskedCSRMatrix<'a, T: Float> {
    matrix: &'a CsrMatrix<T>,
    column_mask: Vec<bool>,
    masked_to_original: Vec<usize>,
    original_to_masked: Vec<Option<usize>>,
}

impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
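    /// Builds a masked view from a boolean mask with one entry per column.
    ///
    /// Panics if the mask length differs from the matrix's column count.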
    pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
        assert_eq!(
            column_mask.len(),
            matrix.ncols(),
            "Column mask must have the same length as the number of columns in the matrix"
        );

        // Build both index maps in a single pass over the mask.
        let mut masked_to_original = Vec::new();
        let mut original_to_masked = vec![None; column_mask.len()];
        let mut masked_index = 0;

        for (i, &is_included) in column_mask.iter().enumerate() {
            if is_included {
                masked_to_original.push(i);
                original_to_masked[i] = Some(masked_index);
                masked_index += 1;
            }
        }

        Self {
            matrix,
            column_mask,
            masked_to_original,
            original_to_masked,
        }
    }

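    /// Convenience constructor: builds the mask from an explicit list of
    /// column indices. Panics if any index is out of bounds.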
    pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
        let mut mask = vec![false; matrix.ncols()];
        for &col in columns {
            assert!(col < matrix.ncols(), "Column index out of bounds");
            mask[col] = true;
        }
        Self::new(matrix, mask)
    }

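    /// Returns true when the mask selects every column of the underlying
    /// matrix, i.e. the view is equivalent to the unmasked matrix.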
    pub fn uses_all_columns(&self) -> bool {
        // The masked count equals ncols exactly when every mask entry is
        // true, so re-scanning the mask itself would be redundant.
        self.masked_to_original.len() == self.matrix.ncols()
    }

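    /// Returns true for tiny (at most 5x5), fully unmasked matrices. In that
    /// case `svd_opa` takes a serial path whose floating-point accumulation
    /// order is deterministic, rather than the parallel path.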
    pub fn ensure_identical_results_mode(&self) -> bool {
        let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
        is_small_matrix && self.uses_all_columns()
    }
}

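/// `SMat` implementation that lets the masked view be fed to the SVD
/// routines directly, without materializing the column subset.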
impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T> {
    fn nrows(&self) -> usize {
        self.matrix.nrows()
    }

    fn ncols(&self) -> usize {
        // The view reports only the selected columns.
        self.masked_to_original.len()
    }

    fn nnz(&self) -> usize {
        // Count only the stored entries that fall in unmasked columns.
        let (major_offsets, minor_indices, _) = self.matrix.csr_data();
        let mut count = 0;

        for i in 0..self.matrix.nrows() {
            for j in major_offsets[i]..major_offsets[i + 1] {
                let col = minor_indices[j];
                if self.column_mask[col] {
                    count += 1;
                }
            }
        }
        count
    }

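    /// Computes the sparse operator product `y = A * x`, or `y = A^T * x`
    /// when `transposed` is true, where `A` is the masked view.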
    fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
        let nrows = if transposed { self.ncols() } else { self.nrows() };
        let ncols = if transposed { self.nrows() } else { self.ncols() };

        assert_eq!(
            x.len(),
            ncols,
            "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
            x.len(),
            ncols
        );
        assert_eq!(
            y.len(),
            nrows,
            "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
            y.len(),
            nrows
        );

        let (major_offsets, minor_indices, values) = self.matrix.csr_data();

        y.fill(T::zero());

        let high_precision_mode = self.ensure_identical_results_mode();

        if !transposed {
            if high_precision_mode {
                // Serial path with a deterministic accumulation order.
                // `ensure_identical_results_mode` already guarantees that all
                // columns are in use, so the unwrap cannot fail.
                for i in 0..self.matrix.nrows() {
                    let mut sum = T::zero();
                    for j in major_offsets[i]..major_offsets[i + 1] {
                        let col = minor_indices[j];
                        let masked_col = self.original_to_masked[col].unwrap();
                        sum = sum + (values[j] * x[masked_col]);
                    }
                    y[i] = sum;
                }
            } else {
                // Parallel path: disjoint chunks of `y` can be written
                // independently, one output row at a time.
                let chunk_size = determine_chunk_size(self.matrix.nrows());
                y.chunks_mut(chunk_size).enumerate().par_bridge().for_each(
                    |(chunk_idx, y_chunk)| {
                        let start_row = chunk_idx * chunk_size;
                        let end_row = (start_row + y_chunk.len()).min(self.matrix.nrows());

                        for i in start_row..end_row {
                            let row_idx = i - start_row;
                            let mut sum = T::zero();

                            for j in major_offsets[i]..major_offsets[i + 1] {
                                let col = minor_indices[j];
                                if let Some(masked_col) = self.original_to_masked[col] {
                                    sum += values[j] * x[masked_col];
                                }
                            }
                            y_chunk[row_idx] = sum;
                        }
                    },
                );
            }
        } else {
            if high_precision_mode {
                // Serial path: `y` was already zeroed above, so each row's
                // contribution can be scattered straight into the output.
                for i in 0..self.matrix.nrows() {
                    let row_val = x[i];
                    for j in major_offsets[i]..major_offsets[i + 1] {
                        let col = minor_indices[j];
                        let masked_col = self.original_to_masked[col].unwrap();
                        y[masked_col] = y[masked_col] + (values[j] * row_val);
                    }
                }
            } else {
                // Parallel path: each chunk of input rows scatters into its
                // own private output vector; the partial sums are combined
                // serially afterwards, avoiding data races on `y`.
                let nrows = self.matrix.nrows();
                let chunk_size = determine_chunk_size(nrows);
                let num_chunks = (nrows + chunk_size - 1) / chunk_size;
                let results: Vec<Vec<T>> = (0..num_chunks)
                    .into_par_iter()
                    .map(|chunk_idx| {
                        let start = chunk_idx * chunk_size;
                        let end = (start + chunk_size).min(nrows);

                        let mut local_y = vec![T::zero(); y.len()];
                        for i in start..end {
                            let row_val = x[i];
                            for j in major_offsets[i]..major_offsets[i + 1] {
                                let col = minor_indices[j];
                                if let Some(masked_col) = self.original_to_masked[col] {
                                    local_y[masked_col] += values[j] * row_val;
                                }
                            }
                        }
                        local_y
                    })
                    .collect();

                // `y` is still all zeros here (nothing has written to it on
                // this branch), so accumulating directly is safe.
                for local_y in results {
                    for (idx, val) in local_y.iter().enumerate() {
                        if !val.is_zero() {
                            y[idx] += *val;
                        }
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{svd, SMat};
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    #[test]
    fn test_masked_matrix() {
        // 3x5 matrix with entries in every column.
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);

        let csr = CsrMatrix::from(&coo);

        // Keep only the even-numbered columns.
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6);

        let svd_result = svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        let mut rng = StdRng::seed_from_u64(42);

        // Random 14x10 matrix with 40 pushed entries (duplicate coordinates
        // are summed when converting to CSR, consistently for both matrices).
        let nrows = 14;
        let ncols = 10;
        let nnz = 40;
        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);
            coo.push(row, col, val);
        }

        let csr = CsrMatrix::from(&coo);

        let selected_columns = vec![1, 3, 5, 7];

        // Masked view over the full matrix.
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Physically extracted submatrix containing the same columns.
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());

        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        let physical_csr = CsrMatrix::from(&physical_subset);

        // The view and the physical copy must agree on shape and sparsity.
        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        let svd_masked = svd(&masked_matrix).unwrap();
        let svd_physical = svd(&physical_csr).unwrap();

        assert_eq!(svd_masked.d, svd_physical.d);

        let epsilon = 1e-10;

        // Compare singular values independently of their output order.
        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }
    }
}