scirs2_transform/decomposition/
dictionary_learning.rs1use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::random::Rng;
10use scirs2_linalg::{svd, vector_norm};
11
12use crate::error::{Result, TransformError};
13
14#[derive(Debug, Clone)]
20pub struct DictionaryLearning {
21 n_components: usize,
23 alpha: f64,
25 max_iter: usize,
27 tol: f64,
29 transform_algorithm: String,
31 random_state: Option<u64>,
33 shuffle: bool,
35 dictionary: Option<Array2<f64>>,
37 n_iter: Option<usize>,
39}
40
41impl DictionaryLearning {
42 pub fn new(ncomponents: usize, alpha: f64) -> Self {
48 DictionaryLearning {
49 n_components: ncomponents,
50 alpha,
51 max_iter: 1000,
52 tol: 1e-4,
53 transform_algorithm: "omp".to_string(),
54 random_state: None,
55 shuffle: true,
56 dictionary: None,
57 n_iter: None,
58 }
59 }
60
61 pub fn with_max_iter(mut self, maxiter: usize) -> Self {
63 self.max_iter = maxiter;
64 self
65 }
66
67 pub fn with_tolerance(mut self, tol: f64) -> Self {
69 self.tol = tol;
70 self
71 }
72
73 pub fn with_transform_algorithm(mut self, algorithm: &str) -> Self {
75 self.transform_algorithm = algorithm.to_string();
76 self
77 }
78
79 pub fn with_random_state(mut self, seed: u64) -> Self {
81 self.random_state = Some(seed);
82 self
83 }
84
85 pub fn with_shuffle(mut self, shuffle: bool) -> Self {
87 self.shuffle = shuffle;
88 self
89 }
90
91 fn initialize_dictionary(&self, x: &Array2<f64>) -> Array2<f64> {
93 let n_features = x.shape()[1];
94 let n_samples = x.shape()[0];
95
96 let mut rng = scirs2_core::random::rng();
97
98 let mut dictionary = Array2::zeros((self.n_components, n_features));
99
100 for i in 0..self.n_components {
102 let idx = rng.gen_range(0..n_samples);
103 dictionary.row_mut(i).assign(&x.row(idx));
104
105 let norm = vector_norm(&dictionary.row(i).view(), 2).unwrap_or(0.0);
107 if norm > 1e-10 {
108 dictionary.row_mut(i).mapv_inplace(|x| x / norm);
109 }
110 }
111
112 dictionary
113 }
114
115 fn omp_sparse_code(
117 &self,
118 x: &Array1<f64>,
119 dictionary: &Array2<f64>,
120 n_nonzero_coefs: usize,
121 ) -> Array1<f64> {
122 let n_atoms = dictionary.shape()[0];
123 let mut residual = x.clone();
124 let mut sparse_code = Array1::zeros(n_atoms);
125 let mut selected_atoms = Vec::new();
126
127 for _ in 0..n_nonzero_coefs.min(n_atoms) {
128 let mut best_atom = 0;
130 let mut best_correlation = 0.0;
131
132 for j in 0..n_atoms {
133 if selected_atoms.contains(&j) {
134 continue;
135 }
136
137 let correlation = residual.dot(&dictionary.row(j)).abs();
138 if correlation > best_correlation {
139 best_correlation = correlation;
140 best_atom = j;
141 }
142 }
143
144 if best_correlation < 1e-10 {
145 break;
146 }
147
148 selected_atoms.push(best_atom);
149
150 if selected_atoms.len() == 1 {
152 let atom = dictionary.row(best_atom);
154 let coef = x.dot(&atom) / atom.dot(&atom);
155 sparse_code[best_atom] = coef;
156 residual = x - &(atom.to_owned() * coef);
157 } else {
158 let n_selected = selected_atoms.len();
160 let mut sub_dictionary = Array2::zeros((n_selected, dictionary.shape()[1]));
161
162 for (i, &atom_idx) in selected_atoms.iter().enumerate() {
163 sub_dictionary.row_mut(i).assign(&dictionary.row(atom_idx));
164 }
165
166 let gram = sub_dictionary.dot(&sub_dictionary.t());
168 let proj = sub_dictionary.dot(&x.view());
169
170 let alpha = self.solve_small_least_squares(&gram, &proj);
172
173 sparse_code.fill(0.0);
175 for (i, &atom_idx) in selected_atoms.iter().enumerate() {
176 sparse_code[atom_idx] = alpha[i];
177 }
178
179 residual = x - &dictionary.t().dot(&sparse_code);
180 }
181 }
182
183 sparse_code
184 }
185
186 fn solve_small_least_squares(&self, a: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
188 let n = a.shape()[0];
189 let mut result = b.clone();
190
191 let mut lu = a.clone();
193 let mut perm = (0..n).collect::<Vec<_>>();
194
195 for k in 0..n - 1 {
197 let mut max_idx = k;
199 let mut max_val = lu[[k, k]].abs();
200 for i in k + 1..n {
201 if lu[[i, k]].abs() > max_val {
202 max_val = lu[[i, k]].abs();
203 max_idx = i;
204 }
205 }
206
207 if max_idx != k {
209 perm.swap(k, max_idx);
210 for j in 0..n {
211 let tmp = lu[[k, j]];
212 lu[[k, j]] = lu[[max_idx, j]];
213 lu[[max_idx, j]] = tmp;
214 }
215 let tmp = result[k];
216 result[k] = result[max_idx];
217 result[max_idx] = tmp;
218 }
219
220 for i in k + 1..n {
222 let factor = lu[[i, k]] / lu[[k, k]];
223 for j in k + 1..n {
224 lu[[i, j]] -= factor * lu[[k, j]];
225 }
226 result[i] -= factor * result[k];
227 }
228 }
229
230 for i in (0..n).rev() {
232 for j in i + 1..n {
233 result[i] -= lu[[i, j]] * result[j];
234 }
235 result[i] /= lu[[i, i]];
236 }
237
238 result
239 }
240
241 fn sparse_code_step(&self, x: &Array2<f64>, dictionary: &Array2<f64>) -> Array2<f64> {
243 let n_samples = x.shape()[0];
244 let n_atoms = dictionary.shape()[0];
245 let mut codes = Array2::zeros((n_samples, n_atoms));
246
247 let n_nonzero_coefs = (self.alpha * n_atoms as f64).ceil() as usize;
249
250 for i in 0..n_samples {
252 let sparse_code =
253 self.omp_sparse_code(&x.row(i).to_owned(), dictionary, n_nonzero_coefs);
254 codes.row_mut(i).assign(&sparse_code);
255 }
256
257 codes
258 }
259
260 fn dictionary_update_step(
262 &self,
263 x: &Array2<f64>,
264 sparse_codes: &mut Array2<f64>,
265 dictionary: &mut Array2<f64>,
266 ) {
267 let n_atoms = dictionary.shape()[0];
268 let n_features = dictionary.shape()[1];
269
270 for k in 0..n_atoms {
271 let mut using_samples = Vec::new();
273 for i in 0..sparse_codes.shape()[0] {
274 if sparse_codes[[i, k]].abs() > 1e-10 {
275 using_samples.push(i);
276 }
277 }
278
279 if using_samples.is_empty() {
280 continue;
281 }
282
283 let mut residual = Array2::zeros((using_samples.len(), n_features));
285 for (idx, &i) in using_samples.iter().enumerate() {
286 let mut r = x.row(i).to_owned();
287 for j in 0..n_atoms {
288 if j != k {
289 r = r - dictionary.row(j).to_owned() * sparse_codes[[i, j]];
290 }
291 }
292 residual.row_mut(idx).assign(&r);
293 }
294
295 if residual.shape()[0] > 0 {
297 match svd::<f64>(&residual.view(), false, Some(1)) {
298 Ok((u, s, vt)) => {
299 dictionary.row_mut(k).assign(&vt.row(0));
301
302 for (idx, &i) in using_samples.iter().enumerate() {
304 sparse_codes[[i, k]] = u[[idx, 0]] * s[0];
305 }
306 }
307 Err(_) => {
308 let norm = vector_norm(&dictionary.row(k).view(), 2).unwrap_or(0.0);
310 if norm > 1e-10 {
311 dictionary.row_mut(k).mapv_inplace(|x| x / norm);
312 }
313 }
314 }
315 }
316 }
317 }
318
319 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
327 where
328 S: Data,
329 S::Elem: Float + NumCast,
330 {
331 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
332 let _n_samples = x_f64.shape()[0];
333 let n_features = x_f64.shape()[1];
334
335 if self.n_components > n_features {
336 return Err(TransformError::InvalidInput(format!(
337 "n_components={} must be <= n_features={}",
338 self.n_components, n_features
339 )));
340 }
341
342 let mut dictionary = self.initialize_dictionary(&x_f64);
344 let mut prev_error = f64::INFINITY;
345 let mut n_iter = 0;
346
347 for iter in 0..self.max_iter {
349 let mut sparse_codes = self.sparse_code_step(&x_f64, &dictionary);
351
352 self.dictionary_update_step(&x_f64, &mut sparse_codes, &mut dictionary);
354
355 let reconstruction = sparse_codes.dot(&dictionary);
357 let error = (&x_f64 - &reconstruction).mapv(|x| x * x).sum().sqrt();
358
359 if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
361 n_iter = iter + 1;
362 break;
363 }
364
365 prev_error = error;
366 n_iter = iter + 1;
367 }
368
369 self.dictionary = Some(dictionary);
370 self.n_iter = Some(n_iter);
371
372 Ok(())
373 }
374
375 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
383 where
384 S: Data,
385 S::Elem: Float + NumCast,
386 {
387 if self.dictionary.is_none() {
388 return Err(TransformError::TransformationError(
389 "DictionaryLearning model has not been fitted".to_string(),
390 ));
391 }
392
393 let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
394 let dictionary = self.dictionary.as_ref().unwrap();
395
396 Ok(self.sparse_code_step(&x_f64, dictionary))
397 }
398
399 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
407 where
408 S: Data,
409 S::Elem: Float + NumCast,
410 {
411 self.fit(x)?;
412 self.transform(x)
413 }
414
415 pub fn dictionary(&self) -> Option<&Array2<f64>> {
417 self.dictionary.as_ref()
418 }
419
420 pub fn n_iterations(&self) -> Option<usize> {
422 self.n_iter
423 }
424
425 pub fn inverse_transform(&self, sparsecodes: &Array2<f64>) -> Result<Array2<f64>> {
427 if self.dictionary.is_none() {
428 return Err(TransformError::TransformationError(
429 "DictionaryLearning model has not been fitted".to_string(),
430 ));
431 }
432
433 let dictionary = self.dictionary.as_ref().unwrap();
434 Ok(sparsecodes.dot(dictionary))
435 }
436}
437
438#[cfg(test)]
439mod tests {
440 use super::*;
441 use scirs2_core::ndarray::Array;
442
/// Fits on synthetic sinusoidal data and checks output shapes, unit-norm
/// atoms, and that reconstruction has the input's shape.
#[test]
#[ignore]
fn test_dictionary_learning_basic() {
    let n_samples = 100;
    let n_features = 20;

    // Row i mixes a sine and a cosine whose frequencies depend on i.
    let x = Array::from_shape_fn((n_samples, n_features), |(i, j)| {
        let t = j as f64 / n_features as f64 * 2.0 * std::f64::consts::PI;
        (t * (i as f64 / 10.0)).sin() + (2.0 * t * (i as f64 / 15.0)).cos()
    });

    let mut dict_learning = DictionaryLearning::new(10, 0.1)
        .with_max_iter(50)
        .with_random_state(42);

    let sparse_codes = dict_learning.fit_transform(&x).unwrap();
    assert_eq!(sparse_codes.shape(), &[n_samples, 10]);

    let dictionary = dict_learning.dictionary().unwrap();
    assert_eq!(dictionary.shape(), &[10, n_features]);

    // Every atom is expected to have unit 2-norm.
    for atom in dictionary.outer_iter() {
        let norm = vector_norm(&atom.view(), 2).unwrap_or(0.0);
        assert!((norm - 1.0).abs() < 1e-5);
    }

    let reconstructed = dict_learning.inverse_transform(&sparse_codes).unwrap();
    assert_eq!(reconstructed.shape(), x.shape());
}
484
/// Codes a scaled identity matrix and checks that more than half of the
/// coefficients are (numerically) zero.
#[test]
fn test_dictionary_learning_sparsity() {
    let x: Array2<f64> = Array::eye(20) * 2.0;

    let mut dict_learning = DictionaryLearning::new(10, 0.05).with_max_iter(30);
    let sparse_codes = dict_learning.fit_transform(&x).unwrap();

    let n_nonzero = sparse_codes.iter().filter(|v| v.abs() > 1e-10).count();
    let density = n_nonzero as f64 / sparse_codes.len() as f64;
    let sparsity = 1.0 - density;

    assert!(sparsity > 0.5);
}
501}