1use crate::{SMat, SvdFloat};
2use nalgebra_sparse::CsrMatrix;
3use num_traits::Float;
4use std::ops::AddAssign;
5
6pub struct MaskedCSRMatrix<'a, T: Float> {
7 matrix: &'a CsrMatrix<T>,
8 column_mask: Vec<bool>,
9 masked_to_original: Vec<usize>,
10 original_to_masked: Vec<Option<usize>>,
11}
12
13impl<'a, T: Float> MaskedCSRMatrix<'a, T> {
14 pub fn new(matrix: &'a CsrMatrix<T>, column_mask: Vec<bool>) -> Self {
15 assert_eq!(
16 column_mask.len(),
17 matrix.ncols(),
18 "Column mask must have the same length as the number of columns in the matrix"
19 );
20
21 let mut masked_to_original = Vec::new();
22 let mut original_to_masked = vec![None; column_mask.len()];
23 let mut masked_index = 0;
24
25 for (i, &is_included) in column_mask.iter().enumerate() {
26 if is_included {
27 masked_to_original.push(i);
28 original_to_masked[i] = Some(masked_index);
29 masked_index += 1;
30 }
31 }
32
33 Self {
34 matrix,
35 column_mask,
36 masked_to_original,
37 original_to_masked,
38 }
39 }
40
41 pub fn with_columns(matrix: &'a CsrMatrix<T>, columns: &[usize]) -> Self {
42 let mut mask = vec![false; matrix.ncols()];
43 for &col in columns {
44 assert!(col < matrix.ncols(), "Column index out of bounds");
45 mask[col] = true;
46 }
47 Self::new(matrix, mask)
48 }
49
50 pub fn uses_all_columns(&self) -> bool {
52 self.masked_to_original.len() == self.matrix.ncols() && self.column_mask.iter().all(|&x| x)
53 }
54
55 pub fn ensure_identical_results_mode(&self) -> bool {
57 let is_small_matrix = self.matrix.nrows() <= 5 && self.matrix.ncols() <= 5;
59 is_small_matrix && self.uses_all_columns()
60 }
61}
62
63impl<'a, T: Float + AddAssign> SMat<T> for MaskedCSRMatrix<'a, T> {
64 fn nrows(&self) -> usize {
65 self.matrix.nrows()
66 }
67
68 fn ncols(&self) -> usize {
69 self.masked_to_original.len()
70 }
71
72 fn nnz(&self) -> usize {
73 let (major_offsets, minor_indices, _) = self.matrix.csr_data();
74 let mut count = 0;
75
76 for i in 0..self.matrix.nrows() {
77 for j in major_offsets[i]..major_offsets[i + 1] {
78 let col = minor_indices[j]; if self.column_mask[col] {
80 count += 1;
81 }
82 }
83 }
84 count
85 }
86
87 fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) {
88 let nrows = if transposed {
89 self.ncols()
90 } else {
91 self.nrows()
92 };
93 let ncols = if transposed {
94 self.nrows()
95 } else {
96 self.ncols()
97 };
98
99 assert_eq!(
100 x.len(),
101 ncols,
102 "svd_opa: x must be A.ncols() in length, x = {}, A.ncols = {}",
103 x.len(),
104 ncols
105 );
106 assert_eq!(
107 y.len(),
108 nrows,
109 "svd_opa: y must be A.nrows() in length, y = {}, A.nrows = {}",
110 y.len(),
111 nrows
112 );
113
114 let (major_offsets, minor_indices, values) = self.matrix.csr_data();
115
116 y.fill(T::zero());
117
118 let high_precision_mode = self.ensure_identical_results_mode();
119
120 if !transposed {
121 if high_precision_mode && self.uses_all_columns() {
122 for i in 0..self.matrix.nrows() {
125 let mut sum = T::zero();
126 for j in major_offsets[i]..major_offsets[i + 1] {
127 let col = minor_indices[j];
128 let masked_col = self.original_to_masked[col].unwrap();
130 sum = sum + (values[j] * x[masked_col]);
131 }
132 y[i] = sum;
133 }
134 } else {
135 for i in 0..self.matrix.nrows() {
137 for j in major_offsets[i]..major_offsets[i + 1] {
138 let col = minor_indices[j];
139 if let Some(masked_col) = self.original_to_masked[col] {
140 y[i] += values[j] * x[masked_col];
141 }
142 }
143 }
144 }
145 } else {
146 if high_precision_mode && self.uses_all_columns() {
148 for yval in y.iter_mut() {
150 *yval = T::zero();
151 }
152
153 for i in 0..self.matrix.nrows() {
155 let row_val = x[i];
156 for j in major_offsets[i]..major_offsets[i + 1] {
157 let col = minor_indices[j];
158 let masked_col = self.original_to_masked[col].unwrap();
159 y[masked_col] = y[masked_col] + (values[j] * row_val);
160 }
161 }
162 } else {
163 for i in 0..self.matrix.nrows() {
165 let row_val = x[i];
166 for j in major_offsets[i]..major_offsets[i + 1] {
167 let col = minor_indices[j];
168 if let Some(masked_col) = self.original_to_masked[col] {
169 y[masked_col] += values[j] * row_val;
170 }
171 }
172 }
173 }
174 }
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{svd, SMat};
    use nalgebra_sparse::{coo::CooMatrix, csr::CsrMatrix};
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// 3x5 fixture: rows 0 and 2 use columns {0, 2, 4}; row 1 uses {1, 3}.
    fn small_fixture() -> CsrMatrix<f64> {
        let mut coo = CooMatrix::<f64>::new(3, 5);
        coo.push(0, 0, 1.0);
        coo.push(0, 2, 2.0);
        coo.push(0, 4, 3.0);
        coo.push(1, 1, 4.0);
        coo.push(1, 3, 5.0);
        coo.push(2, 0, 6.0);
        coo.push(2, 2, 7.0);
        coo.push(2, 4, 8.0);
        CsrMatrix::from(&coo)
    }

    #[test]
    fn test_masked_matrix() {
        let csr = small_fixture();
        let columns = vec![0, 2, 4];
        let masked = MaskedCSRMatrix::with_columns(&csr, &columns);

        assert_eq!(masked.nrows(), 3);
        assert_eq!(masked.ncols(), 3);
        assert_eq!(masked.nnz(), 6);

        let svd_result = svd(&masked);
        assert!(svd_result.is_ok());
    }

    #[test]
    fn test_svd_opa_masked_products() {
        let csr = small_fixture();
        let masked = MaskedCSRMatrix::with_columns(&csr, &[0, 2, 4]);

        // y = A_masked * x; row 1 only has entries in dropped columns.
        // Stale contents of `y` must be overwritten.
        let mut y = vec![99.0; 3];
        masked.svd_opa(&[1.0, -2.0, 0.5], &mut y, false);
        assert_eq!(y, vec![-1.5, 0.0, -4.0]);

        // y = A_masked^T * x; row 1's entries (columns 1, 3) must not leak in.
        let mut y = vec![99.0; 3];
        masked.svd_opa(&[1.0, 1.0, 2.0], &mut y, true);
        assert_eq!(y, vec![13.0, 16.0, 19.0]);
    }

    #[test]
    fn test_full_mask_and_renumbering() {
        let csr = small_fixture();

        // Unsorted input still keeps every column; 3x5 fits the small-size limit.
        let full = MaskedCSRMatrix::with_columns(&csr, &[4, 3, 2, 1, 0]);
        assert!(full.uses_all_columns());
        assert!(full.ensure_identical_results_mode());
        let mut y = vec![0.0; 3];
        full.svd_opa(&[1.0; 5], &mut y, false);
        assert_eq!(y, vec![6.0, 9.0, 21.0]);

        // Kept columns are renumbered in ascending original order: masked 0 is
        // original column 0 and masked 1 is original column 4.
        let partial = MaskedCSRMatrix::with_columns(&csr, &[4, 0]);
        assert!(!partial.uses_all_columns());
        let mut y = vec![0.0; 3];
        partial.svd_opa(&[1.0, 10.0], &mut y, false);
        assert_eq!(y, vec![31.0, 0.0, 86.0]);
    }

    #[test]
    fn test_masked_vs_physical_subset() {
        let mut rng = StdRng::seed_from_u64(42);

        let nrows = 14;
        let ncols = 10;
        let nnz = 40;
        let mut coo = CooMatrix::<f64>::new(nrows, ncols);

        for _ in 0..nnz {
            let row = rng.gen_range(0..nrows);
            let col = rng.gen_range(0..ncols);
            let val = rng.gen_range(0.1..10.0);
            coo.push(row, col, val);
        }

        let csr = CsrMatrix::from(&coo);

        let selected_columns = vec![1, 3, 5, 7];
        let masked_matrix = MaskedCSRMatrix::with_columns(&csr, &selected_columns);

        // Build the same column subset physically to compare against.
        let mut physical_subset = CooMatrix::<f64>::new(nrows, selected_columns.len());
        let col_map: std::collections::HashMap<usize, usize> = selected_columns
            .iter()
            .enumerate()
            .map(|(new_idx, &old_idx)| (old_idx, new_idx))
            .collect();

        for (row, col, val) in coo.triplet_iter() {
            if let Some(&new_col) = col_map.get(&col) {
                physical_subset.push(row, new_col, *val);
            }
        }

        let physical_csr = CsrMatrix::from(&physical_subset);

        assert_eq!(masked_matrix.nrows(), physical_csr.nrows());
        assert_eq!(masked_matrix.ncols(), physical_csr.ncols());
        assert_eq!(masked_matrix.nnz(), physical_csr.nnz());

        let svd_masked = svd(&masked_matrix).unwrap();
        let svd_physical = svd(&physical_csr).unwrap();

        assert_eq!(svd_masked.d, svd_physical.d);

        let epsilon = 1e-10;

        let mut masked_s = svd_masked.s.to_vec();
        let mut physical_s = svd_physical.s.to_vec();
        masked_s.sort_by(|a, b| b.partial_cmp(a).unwrap());
        physical_s.sort_by(|a, b| b.partial_cmp(a).unwrap());

        for (m, p) in masked_s.iter().zip(physical_s.iter()) {
            assert!(
                (m - p).abs() < epsilon,
                "Singular values differ: {} vs {}",
                m,
                p
            );
        }
    }
}