scirs2_transform/decomposition/dictionary_learning.rs

//! Dictionary Learning for sparse coding and representation
//!
//! Dictionary Learning finds a sparse representation for the input data as a linear
//! combination of basic elements called atoms. The atoms compose a dictionary.
//! This is useful for sparse coding, denoising, and feature extraction.

use scirs2_core::ndarray::{Array1, Array2, ArrayBase, Data, Ix2};
use scirs2_core::numeric::{Float, NumCast};
use scirs2_core::random::Rng;
use scirs2_linalg::{svd, vector_norm};

use crate::error::{Result, TransformError};
/// Dictionary Learning for sparse representation
///
/// Finds a dictionary matrix `D` (one atom per row) and a sparse code matrix
/// `Alpha` such that `X ≈ Alpha · D`, where `Alpha` is sparse. This is solved
/// by alternating between sparse coding and dictionary update steps.
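///
/// # Example
///
/// A minimal usage sketch (illustrative only; assumes `DictionaryLearning` is
/// in scope and that the data is meaningful):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// // 100 samples with 20 features each
/// let x = Array2::<f64>::ones((100, 20));
///
/// // Learn 10 atoms; alpha = 0.1 allows ~10% of atoms to be active per sample
/// let mut dl = DictionaryLearning::new(10, 0.1).with_max_iter(50);
/// let codes = dl.fit_transform(&x).unwrap();
/// assert_eq!(codes.shape(), &[100, 10]);
/// ```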
#[derive(Debug, Clone)]
pub struct DictionaryLearning {
    /// Number of dictionary atoms to extract
    n_components: usize,
    /// Sparsity controlling parameter: the fraction of atoms that may be
    /// non-zero in each sample's code
    alpha: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Tolerance for the stopping criterion
    tol: f64,
    /// Algorithm for sparse coding; currently only "omp" is implemented
    transform_algorithm: String,
    /// Random state for reproducibility (stored, but not yet used for seeding)
    random_state: Option<u64>,
    /// Whether to shuffle data before each epoch (currently unused)
    shuffle: bool,
    /// The learned dictionary
    dictionary: Option<Array2<f64>>,
    /// Number of iterations run
    n_iter: Option<usize>,
}

impl DictionaryLearning {
    /// Creates a new DictionaryLearning instance
    ///
    /// # Arguments
    /// * `n_components` - Number of dictionary elements to extract
    /// * `alpha` - Sparsity controlling parameter (fraction of atoms allowed
    ///   to be non-zero per sample)
    pub fn new(n_components: usize, alpha: f64) -> Self {
        DictionaryLearning {
            n_components,
            alpha,
            max_iter: 1000,
            tol: 1e-4,
            transform_algorithm: "omp".to_string(),
            random_state: None,
            shuffle: true,
            dictionary: None,
            n_iter: None,
        }
    }

    /// Set maximum iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set tolerance
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.tol = tol;
        self
    }

    /// Set the sparse coding algorithm (currently only "omp" is implemented)
    pub fn with_transform_algorithm(mut self, algorithm: &str) -> Self {
        self.transform_algorithm = algorithm.to_string();
        self
    }

    /// Set random state
    pub fn with_random_state(mut self, seed: u64) -> Self {
        self.random_state = Some(seed);
        self
    }

    /// Set whether to shuffle data
    pub fn with_shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    /// Initialize the dictionary with randomly chosen samples from the data
    fn initialize_dictionary(&self, x: &Array2<f64>) -> Array2<f64> {
        let n_features = x.shape()[1];
        let n_samples = x.shape()[0];

        // Note: `random_state` is stored on the struct but is not yet wired
        // into this RNG, so initialization is currently unseeded.
        let mut rng = scirs2_core::random::rng();

        let mut dictionary = Array2::zeros((self.n_components, n_features));

        // Select random samples as initial dictionary atoms
        for i in 0..self.n_components {
            let idx = rng.gen_range(0..n_samples);
            dictionary.row_mut(i).assign(&x.row(idx));

            // Normalize the atom to unit L2 norm (removes the scale
            // ambiguity between the dictionary and the codes)
            let norm = vector_norm(&dictionary.row(i).view(), 2).unwrap_or(0.0);
            if norm > 1e-10 {
                dictionary.row_mut(i).mapv_inplace(|v| v / norm);
            }
        }

        dictionary
    }

    /// Orthogonal Matching Pursuit (OMP) for sparse coding
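    ///
    /// Greedy scheme (as implemented below): at each step, select the atom
    /// most correlated with the current residual, re-fit the coefficients of
    /// all selected atoms by least squares, and update the residual
    ///
    /// r ← x − D_Sᵀ · α_S,
    ///
    /// where S is the set of selected atoms (stored as rows of the
    /// dictionary). Iteration stops after `n_nonzero_coefs` atoms, or early
    /// when the residual is numerically uncorrelated with every remaining atom.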
    fn omp_sparse_code(
        &self,
        x: &Array1<f64>,
        dictionary: &Array2<f64>,
        n_nonzero_coefs: usize,
    ) -> Array1<f64> {
        let n_atoms = dictionary.shape()[0];
        let mut residual = x.clone();
        let mut sparse_code = Array1::zeros(n_atoms);
        let mut selected_atoms = Vec::new();

        for _ in 0..n_nonzero_coefs.min(n_atoms) {
            // Find the atom with the highest absolute correlation to the residual
            let mut best_atom = 0;
            let mut best_correlation = 0.0;

            for j in 0..n_atoms {
                if selected_atoms.contains(&j) {
                    continue;
                }

                let correlation = residual.dot(&dictionary.row(j)).abs();
                if correlation > best_correlation {
                    best_correlation = correlation;
                    best_atom = j;
                }
            }

            // Residual is (numerically) orthogonal to all remaining atoms
            if best_correlation < 1e-10 {
                break;
            }

            selected_atoms.push(best_atom);

            // Re-fit the coefficients of all selected atoms by least squares
            if selected_atoms.len() == 1 {
                // Simple case: a single atom has a closed-form coefficient
                let atom = dictionary.row(best_atom);
                let coef = x.dot(&atom) / atom.dot(&atom);
                sparse_code[best_atom] = coef;
                residual = x - &(atom.to_owned() * coef);
            } else {
                // Multiple atoms: solve the least-squares problem restricted
                // to the selected atoms
                let n_selected = selected_atoms.len();
                let mut sub_dictionary = Array2::zeros((n_selected, dictionary.shape()[1]));

                for (i, &atom_idx) in selected_atoms.iter().enumerate() {
                    sub_dictionary.row_mut(i).assign(&dictionary.row(atom_idx));
                }

                // Normal equations: (D_s · D_sᵀ) · α = D_s · x, where the rows
                // of D_s are the selected atoms
                let gram = sub_dictionary.dot(&sub_dictionary.t());
                let proj = sub_dictionary.dot(&x.view());

                // A direct solver is adequate for these small systems
                let alpha = self.solve_small_least_squares(&gram, &proj);

                // Update sparse code and residual
                sparse_code.fill(0.0);
                for (i, &atom_idx) in selected_atoms.iter().enumerate() {
                    sparse_code[atom_idx] = alpha[i];
                }

                residual = x - &dictionary.t().dot(&sparse_code);
            }
        }

        sparse_code
    }

    /// Simple least squares solver for small systems
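    ///
    /// Solves the square system `A · x = b` by Gaussian elimination with
    /// partial pivoting, applying row swaps to both the factors and the
    /// right-hand side in place. In OMP, `A` is the Gram matrix of the
    /// selected atoms, which is symmetric positive definite when the atoms
    /// are linearly independent, so no explicit singularity handling is
    /// attempted here.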
    fn solve_small_least_squares(&self, a: &Array2<f64>, b: &Array1<f64>) -> Array1<f64> {
        let n = a.shape()[0];
        let mut result = b.clone();
        let mut lu = a.clone();

        // Forward elimination with partial pivoting
        for k in 0..n - 1 {
            // Find the pivot row
            let mut max_idx = k;
            let mut max_val = lu[[k, k]].abs();
            for i in k + 1..n {
                if lu[[i, k]].abs() > max_val {
                    max_val = lu[[i, k]].abs();
                    max_idx = i;
                }
            }

            // Swap rows of both the matrix and the right-hand side
            if max_idx != k {
                for j in 0..n {
                    let tmp = lu[[k, j]];
                    lu[[k, j]] = lu[[max_idx, j]];
                    lu[[max_idx, j]] = tmp;
                }
                result.swap(k, max_idx);
            }

            // Eliminate entries below the pivot
            for i in k + 1..n {
                let factor = lu[[i, k]] / lu[[k, k]];
                for j in k + 1..n {
                    lu[[i, j]] -= factor * lu[[k, j]];
                }
                result[i] -= factor * result[k];
            }
        }

        // Back substitution
        for i in (0..n).rev() {
            for j in i + 1..n {
                result[i] -= lu[[i, j]] * result[j];
            }
            result[i] /= lu[[i, i]];
        }

        result
    }

    /// Sparse coding step: find sparse codes for all samples
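    ///
    /// `alpha` is interpreted here as the fraction of atoms allowed to be
    /// active per sample: for example, `alpha = 0.1` with 50 atoms permits
    /// `ceil(0.1 * 50) = 5` non-zero coefficients in each code.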
    fn sparse_code_step(&self, x: &Array2<f64>, dictionary: &Array2<f64>) -> Array2<f64> {
        let n_samples = x.shape()[0];
        let n_atoms = dictionary.shape()[0];
        let mut codes = Array2::zeros((n_samples, n_atoms));

        // Convert the sparsity parameter into a per-sample budget of
        // non-zero coefficients
        let n_nonzero_coefs = (self.alpha * n_atoms as f64).ceil() as usize;

        // Sparse code each sample independently with OMP
        for i in 0..n_samples {
            let sparse_code =
                self.omp_sparse_code(&x.row(i).to_owned(), dictionary, n_nonzero_coefs);
            codes.row_mut(i).assign(&sparse_code);
        }

        codes
    }

    /// Dictionary update step using SVD
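    ///
    /// K-SVD-style update: for each atom `d_k`, restrict attention to the
    /// samples whose codes actually use it, form the residual without that
    /// atom's contribution,
    ///
    /// E_k = X_S − Σ_{j ≠ k} α_j · d_j,
    ///
    /// and replace `d_k` with the top right singular vector of `E_k` (unit
    /// norm by construction), rescaling the corresponding codes by `u_1 · σ_1`.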
    fn dictionary_update_step(
        &self,
        x: &Array2<f64>,
        sparse_codes: &mut Array2<f64>,
        dictionary: &mut Array2<f64>,
    ) {
        let n_atoms = dictionary.shape()[0];
        let n_features = dictionary.shape()[1];

        for k in 0..n_atoms {
            // Find the samples whose codes use this atom
            let mut using_samples = Vec::new();
            for i in 0..sparse_codes.shape()[0] {
                if sparse_codes[[i, k]].abs() > 1e-10 {
                    using_samples.push(i);
                }
            }

            if using_samples.is_empty() {
                continue;
            }

            // Compute the residual of those samples without atom k's contribution
            let mut residual = Array2::zeros((using_samples.len(), n_features));
            for (idx, &i) in using_samples.iter().enumerate() {
                let mut r = x.row(i).to_owned();
                for j in 0..n_atoms {
                    if j != k {
                        r = r - dictionary.row(j).to_owned() * sparse_codes[[i, j]];
                    }
                }
                residual.row_mut(idx).assign(&r);
            }

            // Rank-1 update of the atom via a truncated SVD of the residual
            if residual.shape()[0] > 0 {
                match svd::<f64>(&residual.view(), false, Some(1)) {
                    Ok((u, s, vt)) => {
                        // New atom: the top right singular vector (unit norm)
                        dictionary.row_mut(k).assign(&vt.row(0));

                        // Rescale the codes of the samples that use this atom
                        for (idx, &i) in using_samples.iter().enumerate() {
                            sparse_codes[[i, k]] = u[[idx, 0]] * s[0];
                        }
                    }
                    Err(_) => {
                        // If the SVD fails, fall back to normalizing the current atom
                        let norm = vector_norm(&dictionary.row(k).view(), 2).unwrap_or(0.0);
                        if norm > 1e-10 {
                            dictionary.row_mut(k).mapv_inplace(|v| v / norm);
                        }
                    }
                }
            }
        }
    }

    /// Fit the dictionary learning model
    ///
    /// Alternates between a sparse coding step (OMP) and a dictionary update
    /// step (SVD) until the reconstruction error stops improving.
    ///
    /// # Arguments
    /// * `x` - Input data matrix
    ///
    /// # Returns
    /// * `Result<()>` - Ok if successful, Err otherwise
    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));
        let n_features = x_f64.shape()[1];

        if self.n_components > n_features {
            return Err(TransformError::InvalidInput(format!(
                "n_components={} must be <= n_features={}",
                self.n_components, n_features
            )));
        }

        // Initialize dictionary
        let mut dictionary = self.initialize_dictionary(&x_f64);
        let mut prev_error = f64::INFINITY;
        let mut n_iter = 0;

        // Main optimization loop
        for iter in 0..self.max_iter {
            // Sparse coding step
            let mut sparse_codes = self.sparse_code_step(&x_f64, &dictionary);

            // Dictionary update step
            self.dictionary_update_step(&x_f64, &mut sparse_codes, &mut dictionary);

            // Frobenius norm of the reconstruction residual
            let reconstruction = sparse_codes.dot(&dictionary);
            let error = (&x_f64 - &reconstruction).mapv(|v| v * v).sum().sqrt();

            // Check relative improvement for convergence; the first iteration
            // never triggers this because `prev_error` starts at infinity
            if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
                n_iter = iter + 1;
                break;
            }

            prev_error = error;
            n_iter = iter + 1;
        }

        self.dictionary = Some(dictionary);
        self.n_iter = Some(n_iter);

        Ok(())
    }

    /// Transform data to sparse codes
    ///
    /// # Arguments
    /// * `x` - Input data matrix
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Sparse codes
    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        let dictionary = self.dictionary.as_ref().ok_or_else(|| {
            TransformError::TransformationError(
                "DictionaryLearning model has not been fitted".to_string(),
            )
        })?;

        let x_f64 = x.mapv(|v| NumCast::from(v).unwrap_or(0.0));

        Ok(self.sparse_code_step(&x_f64, dictionary))
    }

    /// Fit and transform in one step
    ///
    /// # Arguments
    /// * `x` - Input data matrix
    ///
    /// # Returns
    /// * `Result<Array2<f64>>` - Sparse codes
    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
    where
        S: Data,
        S::Elem: Float + NumCast,
    {
        self.fit(x)?;
        self.transform(x)
    }

    /// Get the learned dictionary
    pub fn dictionary(&self) -> Option<&Array2<f64>> {
        self.dictionary.as_ref()
    }

    /// Get number of iterations run
    pub fn n_iterations(&self) -> Option<usize> {
        self.n_iter
    }

    /// Reconstruct data from sparse codes
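    ///
    /// A round-trip sketch (illustrative; assumes a fitted model `dl` and
    /// input data `x`):
    ///
    /// ```ignore
    /// let codes = dl.transform(&x)?;
    /// let x_hat = dl.inverse_transform(&codes)?; // codes · dictionary
    /// assert_eq!(x_hat.shape(), x.shape());
    /// ```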
    pub fn inverse_transform(&self, sparse_codes: &Array2<f64>) -> Result<Array2<f64>> {
        let dictionary = self.dictionary.as_ref().ok_or_else(|| {
            TransformError::TransformationError(
                "DictionaryLearning model has not been fitted".to_string(),
            )
        })?;

        Ok(sparse_codes.dot(dictionary))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    #[test]
    #[ignore] // Slow test - dictionary learning takes ~60s
    fn test_dictionary_learning_basic() {
        // Create synthetic data as a sum of sinusoids
        let n_samples = 100;
        let n_features = 20;
        let mut data = Vec::new();

        for i in 0..n_samples {
            for j in 0..n_features {
                let t = j as f64 / n_features as f64 * 2.0 * std::f64::consts::PI;
                let val = (t * (i as f64 / 10.0)).sin() + (2.0 * t * (i as f64 / 15.0)).cos();
                data.push(val);
            }
        }

        let x = Array::from_shape_vec((n_samples, n_features), data).unwrap();

        let mut dict_learning = DictionaryLearning::new(10, 0.1)
            .with_max_iter(50)
            .with_random_state(42);

        let sparse_codes = dict_learning.fit_transform(&x).unwrap();

        // Check dimensions
        assert_eq!(sparse_codes.shape(), &[n_samples, 10]);

        // Check dictionary
        let dictionary = dict_learning.dictionary().unwrap();
        assert_eq!(dictionary.shape(), &[10, n_features]);

        // Check that dictionary atoms are normalized
        for i in 0..10 {
            let norm = vector_norm(&dictionary.row(i).view(), 2).unwrap_or(0.0);
            assert!((norm - 1.0).abs() < 1e-5);
        }

        // Check reconstruction
        let reconstructed = dict_learning.inverse_transform(&sparse_codes).unwrap();
        assert_eq!(reconstructed.shape(), x.shape());
    }

    #[test]
    fn test_dictionary_learning_sparsity() {
        let x: Array2<f64> = Array::eye(20) * 2.0;

        let mut dict_learning = DictionaryLearning::new(10, 0.05).with_max_iter(30);

        let sparse_codes = dict_learning.fit_transform(&x).unwrap();

        // Check sparsity: most elements should be zero
        let n_nonzero = sparse_codes.iter().filter(|&&x| x.abs() > 1e-10).count();
        let total_elements = sparse_codes.len();
        let sparsity = 1.0 - (n_nonzero as f64 / total_elements as f64);

        // Should be quite sparse
        assert!(sparsity > 0.5);
    }
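
    /// Added illustrative test (sketch): with an orthonormal dictionary,
    /// OMP should recover a 1-sparse signal exactly in a single step.
    #[test]
    fn test_omp_recovers_one_sparse_signal() {
        let dl = DictionaryLearning::new(4, 0.25);

        // Four atoms: the rows of the 4x4 identity matrix (already unit-norm)
        let dictionary: Array2<f64> = Array::eye(4);

        // A signal that is exactly 2 * atom_2
        let x = Array1::from(vec![0.0, 0.0, 2.0, 0.0]);

        let code = dl.omp_sparse_code(&x, &dictionary, 1);

        // Exactly one coefficient should be selected, with value 2.0 on atom 2
        assert!((code[2] - 2.0).abs() < 1e-10);
        assert_eq!(code.iter().filter(|&&c| c.abs() > 1e-10).count(), 1);
    }

    /// Added illustrative test (sketch): the small LU solver against a
    /// hand-computed 2x2 system, A = [[4, 1], [1, 3]], b = [1, 2],
    /// whose exact solution is x = [1/11, 7/11].
    #[test]
    fn test_solve_small_least_squares() {
        let dl = DictionaryLearning::new(2, 0.5);

        let a = Array::from_shape_vec((2, 2), vec![4.0, 1.0, 1.0, 3.0]).unwrap();
        let b = Array1::from(vec![1.0, 2.0]);

        let x = dl.solve_small_least_squares(&a, &b);

        assert!((x[0] - 1.0 / 11.0).abs() < 1e-10);
        assert!((x[1] - 7.0 / 11.0).abs() < 1e-10);
    }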
}