scirs2_transform/decomposition/
nmf.rs

1//! Non-negative Matrix Factorization (NMF) for decomposition and feature extraction
2//!
3//! NMF decomposes a non-negative matrix V into two non-negative matrices W and H
4//! such that V ≈ WH. This is useful for parts-based representation and interpretable
5//! feature extraction.
6
7use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::random::Rng;
10
11use crate::error::{Result, TransformError};
12
/// Non-negative Matrix Factorization (NMF)
///
/// Finds two non-negative matrices W and H whose product approximates the
/// non-negative input matrix V. The factors are refined iteratively with either
/// multiplicative updates (`"mu"`) or coordinate descent (`"cd"`), and progress
/// is measured with the Frobenius norm of `V - WH`.
#[derive(Debug, Clone)]
pub struct NMF {
    /// Number of components (latent features)
    n_components: usize,
    /// Initialization method. Only `"random"` and `"nndsvd"` are dispatched on;
    /// any other name falls back to random initialization.
    init: String,
    /// Solver: 'mu' (multiplicative update), 'cd' (coordinate descent)
    solver: String,
    /// Beta divergence parameter (conventionally 2: Frobenius, 1: Kullback-Leibler,
    /// 0: Itakura-Saito). NOTE(review): the value is stored but no solver in this
    /// file reads it; the loss that is evaluated is always the Frobenius norm.
    beta_loss: f64,
    /// Maximum number of iterations
    max_iter: usize,
    /// Tolerance for stopping criteria (relative change of the loss between iterations)
    tol: f64,
    /// Random state for reproducibility
    random_state: Option<u64>,
    /// Overall regularization strength applied to both W and H
    alpha: f64,
    /// L1 ratio for regularization (0: L2, 1: L1)
    l1_ratio: f64,
    /// The H matrix stored by `fit`, shape `(n_components, n_features)`
    components: Option<Array2<f64>>,
    /// The W matrix stored by `fit`, shape `(n_samples, n_components)`
    coefficients: Option<Array2<f64>>,
    /// Frobenius-norm reconstruction error recorded by `fit`
    reconstruction_err: Option<f64>,
    /// Number of iterations run
    n_iter: Option<usize>,
}
47
48impl NMF {
49    /// Creates a new NMF instance
50    ///
51    /// # Arguments
52    /// * `n_components` - Number of components to extract
53    pub fn new(ncomponents: usize) -> Self {
54        NMF {
55            n_components: ncomponents,
56            init: "random".to_string(),
57            solver: "mu".to_string(),
58            beta_loss: 2.0, // Frobenius norm
59            max_iter: 200,
60            tol: 1e-4,
61            random_state: None,
62            alpha: 0.0,
63            l1_ratio: 0.0,
64            components: None,
65            coefficients: None,
66            reconstruction_err: None,
67            n_iter: None,
68        }
69    }
70
71    /// Set the initialization method
72    pub fn with_init(mut self, init: &str) -> Self {
73        self.init = init.to_string();
74        self
75    }
76
77    /// Set the solver
78    pub fn with_solver(mut self, solver: &str) -> Self {
79        self.solver = solver.to_string();
80        self
81    }
82
83    /// Set the beta divergence parameter
84    pub fn with_beta_loss(mut self, beta: f64) -> Self {
85        self.beta_loss = beta;
86        self
87    }
88
89    /// Set maximum iterations
90    pub fn with_max_iter(mut self, maxiter: usize) -> Self {
91        self.max_iter = maxiter;
92        self
93    }
94
95    /// Set tolerance
96    pub fn with_tolerance(mut self, tol: f64) -> Self {
97        self.tol = tol;
98        self
99    }
100
101    /// Set random state
102    pub fn with_random_state(mut self, seed: u64) -> Self {
103        self.random_state = Some(seed);
104        self
105    }
106
107    /// Set regularization parameters
108    pub fn with_regularization(mut self, alpha: f64, l1ratio: f64) -> Self {
109        self.alpha = alpha;
110        self.l1_ratio = l1ratio;
111        self
112    }
113
114    /// Initialize matrices with random non-negative values
115    fn random_initialization(&self, v: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
116        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
117        let mut rng = scirs2_core::random::rng();
118
119        let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
120
121        let mut w = Array2::zeros((n_samples, self.n_components));
122        let mut h = Array2::zeros((self.n_components, n_features));
123
124        for i in 0..n_samples {
125            for j in 0..self.n_components {
126                w[[i, j]] = rng.random::<f64>() * scale;
127            }
128        }
129
130        for i in 0..self.n_components {
131            for j in 0..n_features {
132                h[[i, j]] = rng.random::<f64>() * scale;
133            }
134        }
135
136        (w, h)
137    }
138
139    /// NNDSVD initialization (Non-negative Double Singular Value Decomposition)
140    fn nndsvd_initialization(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
141        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
142
143        // Perform SVD
144        let (u, s, vt) = match scirs2_linalg::svd::<f64>(&v.view(), true, None) {
145            Ok(result) => result,
146            Err(e) => return Err(TransformError::LinalgError(e)),
147        };
148
149        let mut w = Array2::zeros((n_samples, self.n_components));
150        let mut h = Array2::zeros((self.n_components, n_features));
151
152        // Use the first n_components singular vectors
153        for j in 0..self.n_components {
154            let x = u.column(j);
155            let y = vt.row(j);
156
157            // Make non-negative
158            let x_pos = x.mapv(|v| v.max(0.0));
159            let x_neg = x.mapv(|v| (-v).max(0.0));
160            let y_pos = y.mapv(|v| v.max(0.0));
161            let y_neg = y.mapv(|v| (-v).max(0.0));
162
163            let x_pos_norm = x_pos.dot(&x_pos).sqrt();
164            let x_neg_norm = x_neg.dot(&x_neg).sqrt();
165            let y_pos_norm = y_pos.dot(&y_pos).sqrt();
166            let y_neg_norm = y_neg.dot(&y_neg).sqrt();
167
168            let m_pos = x_pos_norm * y_pos_norm;
169            let m_neg = x_neg_norm * y_neg_norm;
170
171            if m_pos > m_neg {
172                for i in 0..n_samples {
173                    w[[i, j]] = (s[j].sqrt() * x_pos[i] / x_pos_norm).max(0.0);
174                }
175                for i in 0..n_features {
176                    h[[j, i]] = (s[j].sqrt() * y_pos[i] / y_pos_norm).max(0.0);
177                }
178            } else {
179                for i in 0..n_samples {
180                    w[[i, j]] = (s[j].sqrt() * x_neg[i] / x_neg_norm).max(0.0);
181                }
182                for i in 0..n_features {
183                    h[[j, i]] = (s[j].sqrt() * y_neg[i] / y_neg_norm).max(0.0);
184                }
185            }
186        }
187
188        Ok((w, h))
189    }
190
191    /// Initialize W and H matrices
192    fn initialize_matrices(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
193        match self.init.as_str() {
194            "random" => Ok(self.random_initialization(v)),
195            "nndsvd" => self.nndsvd_initialization(v),
196            _ => Ok(self.random_initialization(v)),
197        }
198    }
199
200    /// Compute Frobenius norm loss
201    fn frobenius_loss(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
202        let wh = w.dot(h);
203        let diff = v - &wh;
204        diff.mapv(|x| x * x).sum().sqrt()
205    }
206
207    /// Multiplicative update for W
208    fn update_w(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
209        let eps = 1e-10;
210        let wh = w.dot(h);
211
212        // Numerator: V * H^T
213        let numerator = v.dot(&h.t());
214
215        // Denominator: W * H * H^T + regularization
216        let mut denominator = wh.dot(&h.t());
217
218        // Add L2 regularization
219        if self.alpha > 0.0 && self.l1_ratio < 1.0 {
220            let l2_reg = self.alpha * (1.0 - self.l1_ratio);
221            denominator = &denominator + &(w * l2_reg);
222        }
223
224        // Add L1 regularization
225        if self.alpha > 0.0 && self.l1_ratio > 0.0 {
226            let l1_reg = self.alpha * self.l1_ratio;
227            denominator = denominator.mapv(|x| x + l1_reg);
228        }
229
230        // Multiplicative update
231        let mut w_new = w * &(numerator / (denominator + eps));
232
233        // Ensure non-negativity
234        w_new.mapv_inplace(|x| x.max(eps));
235
236        w_new
237    }
238
239    /// Multiplicative update for H
240    fn update_h(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
241        let eps = 1e-10;
242        let wh = w.dot(h);
243
244        // Numerator: W^T * V
245        let numerator = w.t().dot(v);
246
247        // Denominator: W^T * W * H + regularization
248        let mut denominator = w.t().dot(&wh);
249
250        // Add L2 regularization
251        if self.alpha > 0.0 && self.l1_ratio < 1.0 {
252            let l2_reg = self.alpha * (1.0 - self.l1_ratio);
253            denominator = &denominator + &(h * l2_reg);
254        }
255
256        // Add L1 regularization
257        if self.alpha > 0.0 && self.l1_ratio > 0.0 {
258            let l1_reg = self.alpha * self.l1_ratio;
259            denominator = denominator.mapv(|x| x + l1_reg);
260        }
261
262        // Multiplicative update
263        let mut h_new = h * &(numerator / (denominator + eps));
264
265        // Ensure non-negativity
266        h_new.mapv_inplace(|x| x.max(eps));
267
268        h_new
269    }
270
271    /// Coordinate descent update for W
272    fn update_w_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
273        let eps = 1e-10;
274        let (n_samples, n_components) = w.dim();
275        let mut w_new = w.clone();
276
277        // Precompute H * H^T for efficiency
278        let hht = h.dot(&h.t());
279
280        for i in 0..n_samples {
281            for j in 0..n_components {
282                // Compute residual without contribution from w[i,j]
283                let mut numerator = 0.0;
284                let mut denominator = hht[[j, j]];
285
286                // Compute v[i,:] * h[j,:] (numerator)
287                for k in 0..h.ncols() {
288                    numerator += v[[i, k]] * h[[j, k]];
289                }
290
291                // Compute w[i,:] * (H * H^T)[j,:] excluding w[i,j]
292                for k in 0..n_components {
293                    if k != j {
294                        numerator -= w_new[[i, k]] * hht[[k, j]];
295                    }
296                }
297
298                // Add regularization terms
299                if self.alpha > 0.0 {
300                    if self.l1_ratio > 0.0 {
301                        // L1 regularization (soft thresholding)
302                        let l1_penalty = self.alpha * self.l1_ratio;
303                        numerator -= l1_penalty;
304                    }
305                    if self.l1_ratio < 1.0 {
306                        // L2 regularization
307                        let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
308                        denominator += l2_penalty;
309                        numerator -= l2_penalty * w_new[[i, j]];
310                    }
311                }
312
313                // Update w[i,j]
314                let new_val = if denominator > eps {
315                    (numerator / denominator).max(eps)
316                } else {
317                    eps
318                };
319
320                w_new[[i, j]] = new_val;
321            }
322        }
323
324        w_new
325    }
326
327    /// Coordinate descent update for H
328    fn update_h_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
329        let eps = 1e-10;
330        let (n_components, n_features) = h.dim();
331        let mut h_new = h.clone();
332
333        // Precompute W^T * W for efficiency
334        let wtw = w.t().dot(w);
335
336        for i in 0..n_components {
337            for j in 0..n_features {
338                // Compute residual without contribution from h[i,j]
339                let mut numerator = 0.0;
340                let mut denominator = wtw[[i, i]];
341
342                // Compute w[:,i]^T * v[:,j] (numerator)
343                for k in 0..w.nrows() {
344                    numerator += w[[k, i]] * v[[k, j]];
345                }
346
347                // Compute (W^T * W)[i,:] * h[:,j] excluding h[i,j]
348                for k in 0..n_components {
349                    if k != i {
350                        numerator -= wtw[[i, k]] * h_new[[k, j]];
351                    }
352                }
353
354                // Add regularization terms
355                if self.alpha > 0.0 {
356                    if self.l1_ratio > 0.0 {
357                        // L1 regularization (soft thresholding)
358                        let l1_penalty = self.alpha * self.l1_ratio;
359                        numerator -= l1_penalty;
360                    }
361                    if self.l1_ratio < 1.0 {
362                        // L2 regularization
363                        let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
364                        denominator += l2_penalty;
365                        numerator -= l2_penalty * h_new[[i, j]];
366                    }
367                }
368
369                // Update h[i,j]
370                let new_val = if denominator > eps {
371                    (numerator / denominator).max(eps)
372                } else {
373                    eps
374                };
375
376                h_new[[i, j]] = new_val;
377            }
378        }
379
380        h_new
381    }
382
383    /// Fit NMF model to data
384    ///
385    /// # Arguments
386    /// * `x` - Input data matrix (must be non-negative)
387    ///
388    /// # Returns
389    /// * `Result<()>` - Ok if successful, Err otherwise
390    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
391    where
392        S: Data,
393        S::Elem: Float + NumCast,
394    {
395        // Validate non-negativity before conversion
396        for elem in x.iter() {
397            let val = NumCast::from(*elem).unwrap_or(0.0);
398            if val < 0.0 {
399                return Err(TransformError::InvalidInput(
400                    "NMF requires non-negative input data".to_string(),
401                ));
402            }
403        }
404
405        // Convert to f64
406        let v = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
407
408        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
409
410        if self.n_components > n_features.min(n_samples) {
411            return Err(TransformError::InvalidInput(format!(
412                "n_components={} must be <= min(n_samples={}, n_features={})",
413                self.n_components, n_samples, n_features
414            )));
415        }
416
417        // Initialize W and H
418        let (mut w, mut h) = self.initialize_matrices(&v)?;
419
420        let mut prev_error = self.frobenius_loss(&v, &w, &h);
421        let mut n_iter = 0;
422
423        // Optimization loop
424        for iter in 0..self.max_iter {
425            // Update W and H
426            if self.solver == "mu" {
427                h = self.update_h(&v, &w, &h);
428                w = self.update_w(&v, &w, &h);
429            } else if self.solver == "cd" {
430                h = self.update_h_cd(&v, &w, &h);
431                w = self.update_w_cd(&v, &w, &h);
432            } else {
433                return Err(TransformError::InvalidInput(format!(
434                    "Unknown solver '{}'. Supported solvers: 'mu', 'cd'",
435                    self.solver
436                )));
437            }
438
439            // Compute error
440            let error = self.frobenius_loss(&v, &w, &h);
441
442            // Check convergence
443            if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
444                n_iter = iter + 1;
445                break;
446            }
447
448            prev_error = error;
449            n_iter = iter + 1;
450        }
451
452        self.components = Some(h);
453        self.coefficients = Some(w);
454        self.reconstruction_err = Some(prev_error);
455        self.n_iter = Some(n_iter);
456
457        Ok(())
458    }
459
460    /// Transform data using fitted NMF model
461    ///
462    /// # Arguments
463    /// * `x` - Input data matrix (must be non-negative)
464    ///
465    /// # Returns
466    /// * `Result<Array2<f64>>` - Transformed data (W matrix)
467    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
468    where
469        S: Data,
470        S::Elem: Float + NumCast,
471    {
472        if self.components.is_none() {
473            return Err(TransformError::TransformationError(
474                "NMF model has not been fitted".to_string(),
475            ));
476        }
477
478        // Validate non-negativity before conversion
479        for elem in x.iter() {
480            let val = NumCast::from(*elem).unwrap_or(0.0);
481            if val < 0.0 {
482                return Err(TransformError::InvalidInput(
483                    "NMF requires non-negative input data".to_string(),
484                ));
485            }
486        }
487
488        // Convert to f64
489        let v = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
490
491        let h = self.components.as_ref().unwrap();
492        let n_samples = v.shape()[0];
493
494        // Initialize W randomly
495        let mut rng = scirs2_core::random::rng();
496
497        let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
498        let mut w = Array2::zeros((n_samples, self.n_components));
499
500        for i in 0..n_samples {
501            for j in 0..self.n_components {
502                w[[i, j]] = rng.random::<f64>() * scale;
503            }
504        }
505
506        // Update W while keeping H fixed
507        for _ in 0..self.max_iter {
508            w = self.update_w(&v, &w, h);
509        }
510
511        Ok(w)
512    }
513
514    /// Fit and transform in one step
515    ///
516    /// # Arguments
517    /// * `x` - Input data matrix (must be non-negative)
518    ///
519    /// # Returns
520    /// * `Result<Array2<f64>>` - Transformed data (W matrix)
521    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
522    where
523        S: Data,
524        S::Elem: Float + NumCast,
525    {
526        self.fit(x)?;
527        Ok(self.coefficients.as_ref().unwrap().clone())
528    }
529
    /// Get the components (H matrix)
    ///
    /// Returns `None` until `fit` has stored a result. The stored matrix has
    /// shape `(n_components, n_features)`.
    pub fn components(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }
534
    /// Get the coefficients (W matrix)
    ///
    /// Returns `None` until `fit` has stored a result. The stored matrix has
    /// shape `(n_samples, n_components)`.
    pub fn coefficients(&self) -> Option<&Array2<f64>> {
        self.coefficients.as_ref()
    }
539
    /// Get reconstruction error
    ///
    /// This is the Frobenius-norm loss value recorded by `fit`, or `None` if
    /// the model has not been fitted.
    pub fn reconstruction_error(&self) -> Option<f64> {
        self.reconstruction_err
    }
544
    /// Get number of iterations run
    ///
    /// Returns the iteration count recorded by `fit`, or `None` if the model
    /// has not been fitted.
    pub fn n_iterations(&self) -> Option<usize> {
        self.n_iter
    }
549
550    /// Inverse transform - reconstruct data from transformed representation
551    pub fn inverse_transform(&self, w: &Array2<f64>) -> Result<Array2<f64>> {
552        if self.components.is_none() {
553            return Err(TransformError::TransformationError(
554                "NMF model has not been fitted".to_string(),
555            ));
556        }
557
558        let h = self.components.as_ref().unwrap();
559        Ok(w.dot(h))
560    }
561}
562
563#[cfg(test)]
564mod tests {
565    use super::*;
566    use scirs2_core::ndarray::Array;
567
568    #[test]
569    fn test_nmf_basic() {
570        // Create non-negative data
571        let x = Array::from_shape_vec(
572            (6, 4),
573            vec![
574                1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
575                5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0,
576            ],
577        )
578        .unwrap();
579
580        let mut nmf = NMF::new(2).with_max_iter(100).with_random_state(42);
581
582        let w = nmf.fit_transform(&x).unwrap();
583
584        // Check dimensions
585        assert_eq!(w.shape(), &[6, 2]);
586
587        // Check non-negativity
588        for val in w.iter() {
589            assert!(*val >= 0.0);
590        }
591
592        // Check components
593        let h = nmf.components().unwrap();
594        assert_eq!(h.shape(), &[2, 4]);
595
596        for val in h.iter() {
597            assert!(*val >= 0.0);
598        }
599
600        // Check reconstruction
601        let x_reconstructed = nmf.inverse_transform(&w).unwrap();
602        assert_eq!(x_reconstructed.shape(), x.shape());
603    }
604
605    #[test]
606    fn test_nmf_regularization() {
607        let x = Array2::<f64>::eye(10) + 0.1; // Add small value to ensure positivity
608
609        let mut nmf = NMF::new(3).with_regularization(0.1, 0.5).with_max_iter(50);
610
611        let result = nmf.fit_transform(&x);
612        assert!(result.is_ok());
613
614        let w = result.unwrap();
615        assert_eq!(w.shape(), &[10, 3]);
616    }
617
618    #[test]
619    fn test_nmf_negative_input() {
620        let x = Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, -1.0, 5.0, 6.0, 7.0, 8.0, 9.0])
621            .unwrap();
622
623        let mut nmf = NMF::new(2);
624        let result = nmf.fit(&x);
625
626        assert!(result.is_err());
627        if let Err(e) = result {
628            assert!(e
629                .to_string()
630                .contains("NMF requires non-negative input data"));
631        }
632    }
633
634    #[test]
635    fn test_nmf_coordinate_descent() {
636        // Create non-negative data
637        let x = Array::from_shape_vec(
638            (6, 4),
639            vec![
640                1.0, 2.0, 3.0, 4.0, 2.0, 4.0, 6.0, 8.0, 3.0, 6.0, 9.0, 12.0, 4.0, 8.0, 12.0, 16.0,
641                5.0, 10.0, 15.0, 20.0, 6.0, 12.0, 18.0, 24.0,
642            ],
643        )
644        .unwrap();
645
646        let mut nmf_cd = NMF::new(2)
647            .with_solver("cd")
648            .with_max_iter(100)
649            .with_random_state(42);
650
651        let w_cd = nmf_cd.fit_transform(&x).unwrap();
652
653        // Check dimensions
654        assert_eq!(w_cd.shape(), &[6, 2]);
655
656        // Check non-negativity
657        for val in w_cd.iter() {
658            assert!(*val >= 0.0);
659        }
660
661        // Check components
662        let h_cd = nmf_cd.components().unwrap();
663        assert_eq!(h_cd.shape(), &[2, 4]);
664
665        for val in h_cd.iter() {
666            assert!(*val >= 0.0);
667        }
668
669        // Check reconstruction
670        let x_reconstructed = nmf_cd.inverse_transform(&w_cd).unwrap();
671        assert_eq!(x_reconstructed.shape(), x.shape());
672
673        // Compare with multiplicative update solver
674        let mut nmf_mu = NMF::new(2)
675            .with_solver("mu")
676            .with_max_iter(100)
677            .with_random_state(42);
678
679        let _w_mu = nmf_mu.fit_transform(&x).unwrap();
680
681        // Both should converge and produce valid decompositions
682        assert!(nmf_cd.reconstruction_error().unwrap() >= 0.0);
683        assert!(nmf_mu.reconstruction_error().unwrap() >= 0.0);
684    }
685
686    #[test]
687    fn test_nmf_invalid_solver() {
688        let x = Array2::<f64>::eye(3) + 0.1;
689        let mut nmf = NMF::new(2).with_solver("invalid");
690
691        let result = nmf.fit(&x);
692        assert!(result.is_err());
693        assert!(result.unwrap_err().to_string().contains("Unknown solver"));
694    }
695}