scirs2_transform/decomposition/
nmf.rs

//! Non-negative Matrix Factorization (NMF) for decomposition and feature extraction
//!
//! NMF decomposes a non-negative matrix V into two non-negative matrices W and H
//! such that V ≈ WH. Because all factors are non-negative, components can only be
//! combined additively, which makes NMF useful for parts-based representation and
//! interpretable feature extraction.

use ndarray::{Array2, ArrayBase, Data, Ix2};
use num_traits::{Float, NumCast};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::error::{Result, TransformError};
// use statrs::statistics::Statistics; // TODO: Add statrs dependency

14/// Non-negative Matrix Factorization (NMF)
15///
16/// Finds two non-negative matrices W and H whose product approximates the
17/// non-negative input matrix V. The objective function is minimized with
18/// multiplicative update rules.
19#[derive(Debug, Clone)]
20pub struct NMF {
21    /// Number of components (latent features)
22    n_components: usize,
23    /// Initialization method: 'random', 'nndsvd', 'nndsvda', 'nndsvdar'
24    init: String,
25    /// Solver: 'mu' (multiplicative update), 'cd' (coordinate descent)
26    solver: String,
27    /// Beta divergence parameter (0: Euclidean, 1: KL divergence, 2: Frobenius)
28    beta_loss: f64,
29    /// Maximum number of iterations
30    max_iter: usize,
31    /// Tolerance for stopping criteria
32    tol: f64,
33    /// Random state for reproducibility
34    random_state: Option<u64>,
35    /// Regularization parameter for components
36    alpha: f64,
37    /// L1 ratio for regularization (0: L2, 1: L1)
38    l1_ratio: f64,
39    /// The basis matrix W
40    components: Option<Array2<f64>>,
41    /// The coefficient matrix H
42    coefficients: Option<Array2<f64>>,
43    /// Reconstruction error
44    reconstruction_err: Option<f64>,
45    /// Number of iterations run
46    n_iter: Option<usize>,
47}
48
49impl NMF {
50    /// Creates a new NMF instance
51    ///
52    /// # Arguments
53    /// * `n_components` - Number of components to extract
54    pub fn new(ncomponents: usize) -> Self {
55        NMF {
56            n_components: ncomponents,
57            init: "random".to_string(),
58            solver: "mu".to_string(),
59            beta_loss: 2.0, // Frobenius norm
60            max_iter: 200,
61            tol: 1e-4,
62            random_state: None,
63            alpha: 0.0,
64            l1_ratio: 0.0,
65            components: None,
66            coefficients: None,
67            reconstruction_err: None,
68            n_iter: None,
69        }
70    }
71
72    /// Set the initialization method
73    pub fn with_init(mut self, init: &str) -> Self {
74        self.init = init.to_string();
75        self
76    }
77
78    /// Set the solver
79    pub fn with_solver(mut self, solver: &str) -> Self {
80        self.solver = solver.to_string();
81        self
82    }
83
84    /// Set the beta divergence parameter
85    pub fn with_beta_loss(mut self, beta: f64) -> Self {
86        self.beta_loss = beta;
87        self
88    }
89
90    /// Set maximum iterations
91    pub fn with_max_iter(mut self, maxiter: usize) -> Self {
92        self.max_iter = maxiter;
93        self
94    }
95
96    /// Set tolerance
97    pub fn with_tolerance(mut self, tol: f64) -> Self {
98        self.tol = tol;
99        self
100    }
101
102    /// Set random state
103    pub fn with_random_state(mut self, seed: u64) -> Self {
104        self.random_state = Some(seed);
105        self
106    }
107
108    /// Set regularization parameters
109    pub fn with_regularization(mut self, alpha: f64, l1ratio: f64) -> Self {
110        self.alpha = alpha;
111        self.l1_ratio = l1ratio;
112        self
113    }
114
115    /// Initialize matrices with random non-negative values
116    fn random_initialization(&self, v: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
117        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
118        let mut rng = rand::rng();
119
120        let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
121
122        let mut w = Array2::zeros((n_samples, self.n_components));
123        let mut h = Array2::zeros((self.n_components, n_features));
124
125        for i in 0..n_samples {
126            for j in 0..self.n_components {
127                w[[i, j]] = rng.random::<f64>() * scale;
128            }
129        }
130
131        for i in 0..self.n_components {
132            for j in 0..n_features {
133                h[[i, j]] = rng.random::<f64>() * scale;
134            }
135        }
136
137        (w, h)
138    }
139
140    /// NNDSVD initialization (Non-negative Double Singular Value Decomposition)
141    fn nndsvd_initialization(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
142        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
143
144        // Perform SVD
145        let (u, s, vt) = match scirs2_linalg::svd::<f64>(&v.view(), true, None) {
146            Ok(result) => result,
147            Err(e) => return Err(TransformError::LinalgError(e)),
148        };
149
150        let mut w = Array2::zeros((n_samples, self.n_components));
151        let mut h = Array2::zeros((self.n_components, n_features));
152
153        // Use the first n_components singular vectors
154        for j in 0..self.n_components {
155            let x = u.column(j);
156            let y = vt.row(j);
157
158            // Make non-negative
159            let x_pos = x.mapv(|v| v.max(0.0));
160            let x_neg = x.mapv(|v| (-v).max(0.0));
161            let y_pos = y.mapv(|v| v.max(0.0));
162            let y_neg = y.mapv(|v| (-v).max(0.0));
163
164            let x_pos_norm = x_pos.dot(&x_pos).sqrt();
165            let x_neg_norm = x_neg.dot(&x_neg).sqrt();
166            let y_pos_norm = y_pos.dot(&y_pos).sqrt();
167            let y_neg_norm = y_neg.dot(&y_neg).sqrt();
168
169            let m_pos = x_pos_norm * y_pos_norm;
170            let m_neg = x_neg_norm * y_neg_norm;
171
172            if m_pos > m_neg {
173                for i in 0..n_samples {
174                    w[[i, j]] = (s[j].sqrt() * x_pos[i] / x_pos_norm).max(0.0);
175                }
176                for i in 0..n_features {
177                    h[[j, i]] = (s[j].sqrt() * y_pos[i] / y_pos_norm).max(0.0);
178                }
179            } else {
180                for i in 0..n_samples {
181                    w[[i, j]] = (s[j].sqrt() * x_neg[i] / x_neg_norm).max(0.0);
182                }
183                for i in 0..n_features {
184                    h[[j, i]] = (s[j].sqrt() * y_neg[i] / y_neg_norm).max(0.0);
185                }
186            }
187        }
188
189        Ok((w, h))
190    }
191
192    /// Initialize W and H matrices
193    fn initialize_matrices(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
194        match self.init.as_str() {
195            "random" => Ok(self.random_initialization(v)),
196            "nndsvd" => self.nndsvd_initialization(v),
197            _ => Ok(self.random_initialization(v)),
198        }
199    }
200
201    /// Compute Frobenius norm loss
202    fn frobenius_loss(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
203        let wh = w.dot(h);
204        let diff = v - &wh;
205        diff.mapv(|x| x * x).sum().sqrt()
206    }
207
208    /// Multiplicative update for W
209    fn update_w(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
210        let eps = 1e-10;
211        let wh = w.dot(h);
212
213        // Numerator: V * H^T
214        let numerator = v.dot(&h.t());
215
216        // Denominator: W * H * H^T + regularization
217        let mut denominator = wh.dot(&h.t());
218
219        // Add L2 regularization
220        if self.alpha > 0.0 && self.l1_ratio < 1.0 {
221            let l2_reg = self.alpha * (1.0 - self.l1_ratio);
222            denominator = &denominator + &(w * l2_reg);
223        }
224
225        // Add L1 regularization
226        if self.alpha > 0.0 && self.l1_ratio > 0.0 {
227            let l1_reg = self.alpha * self.l1_ratio;
228            denominator = denominator.mapv(|x| x + l1_reg);
229        }
230
231        // Multiplicative update
232        let mut w_new = w * &(numerator / (denominator + eps));
233
234        // Ensure non-negativity
235        w_new.mapv_inplace(|x| x.max(eps));
236
237        w_new
238    }
239
240    /// Multiplicative update for H
241    fn update_h(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
242        let eps = 1e-10;
243        let wh = w.dot(h);
244
245        // Numerator: W^T * V
246        let numerator = w.t().dot(v);
247
248        // Denominator: W^T * W * H + regularization
249        let mut denominator = w.t().dot(&wh);
250
251        // Add L2 regularization
252        if self.alpha > 0.0 && self.l1_ratio < 1.0 {
253            let l2_reg = self.alpha * (1.0 - self.l1_ratio);
254            denominator = &denominator + &(h * l2_reg);
255        }
256
257        // Add L1 regularization
258        if self.alpha > 0.0 && self.l1_ratio > 0.0 {
259            let l1_reg = self.alpha * self.l1_ratio;
260            denominator = denominator.mapv(|x| x + l1_reg);
261        }
262
263        // Multiplicative update
264        let mut h_new = h * &(numerator / (denominator + eps));
265
266        // Ensure non-negativity
267        h_new.mapv_inplace(|x| x.max(eps));
268
269        h_new
270    }
271
272    /// Coordinate descent update for W
273    fn update_w_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
274        let eps = 1e-10;
275        let (n_samples, n_components) = w.dim();
276        let mut w_new = w.clone();
277
278        // Precompute H * H^T for efficiency
279        let hht = h.dot(&h.t());
280
281        for i in 0..n_samples {
282            for j in 0..n_components {
283                // Compute residual without contribution from w[i,j]
284                let mut numerator = 0.0;
285                let mut denominator = hht[[j, j]];
286
287                // Compute v[i,:] * h[j,:] (numerator)
288                for k in 0..h.ncols() {
289                    numerator += v[[i, k]] * h[[j, k]];
290                }
291
292                // Compute w[i,:] * (H * H^T)[j,:] excluding w[i,j]
293                for k in 0..n_components {
294                    if k != j {
295                        numerator -= w_new[[i, k]] * hht[[k, j]];
296                    }
297                }
298
299                // Add regularization terms
300                if self.alpha > 0.0 {
301                    if self.l1_ratio > 0.0 {
302                        // L1 regularization (soft thresholding)
303                        let l1_penalty = self.alpha * self.l1_ratio;
304                        numerator -= l1_penalty;
305                    }
306                    if self.l1_ratio < 1.0 {
307                        // L2 regularization
308                        let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
309                        denominator += l2_penalty;
310                        numerator -= l2_penalty * w_new[[i, j]];
311                    }
312                }
313
314                // Update w[i,j]
315                let new_val = if denominator > eps {
316                    (numerator / denominator).max(eps)
317                } else {
318                    eps
319                };
320
321                w_new[[i, j]] = new_val;
322            }
323        }
324
325        w_new
326    }
327
328    /// Coordinate descent update for H
329    fn update_h_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
330        let eps = 1e-10;
331        let (n_components, n_features) = h.dim();
332        let mut h_new = h.clone();
333
334        // Precompute W^T * W for efficiency
335        let wtw = w.t().dot(w);
336
337        for i in 0..n_components {
338            for j in 0..n_features {
339                // Compute residual without contribution from h[i,j]
340                let mut numerator = 0.0;
341                let mut denominator = wtw[[i, i]];
342
343                // Compute w[:,i]^T * v[:,j] (numerator)
344                for k in 0..w.nrows() {
345                    numerator += w[[k, i]] * v[[k, j]];
346                }
347
348                // Compute (W^T * W)[i,:] * h[:,j] excluding h[i,j]
349                for k in 0..n_components {
350                    if k != i {
351                        numerator -= wtw[[i, k]] * h_new[[k, j]];
352                    }
353                }
354
355                // Add regularization terms
356                if self.alpha > 0.0 {
357                    if self.l1_ratio > 0.0 {
358                        // L1 regularization (soft thresholding)
359                        let l1_penalty = self.alpha * self.l1_ratio;
360                        numerator -= l1_penalty;
361                    }
362                    if self.l1_ratio < 1.0 {
363                        // L2 regularization
364                        let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
365                        denominator += l2_penalty;
366                        numerator -= l2_penalty * h_new[[i, j]];
367                    }
368                }
369
370                // Update h[i,j]
371                let new_val = if denominator > eps {
372                    (numerator / denominator).max(eps)
373                } else {
374                    eps
375                };
376
377                h_new[[i, j]] = new_val;
378            }
379        }
380
381        h_new
382    }
383
384    /// Fit NMF model to data
385    ///
386    /// # Arguments
387    /// * `x` - Input data matrix (must be non-negative)
388    ///
389    /// # Returns
390    /// * `Result<()>` - Ok if successful, Err otherwise
391    pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
392    where
393        S: Data,
394        S::Elem: Float + NumCast,
395    {
396        // Validate non-negativity before conversion
397        for elem in x.iter() {
398            let val = num_traits::cast::<S::Elem, f64>(*elem).unwrap_or(0.0);
399            if val < 0.0 {
400                return Err(TransformError::InvalidInput(
401                    "NMF requires non-negative input data".to_string(),
402                ));
403            }
404        }
405
406        // Convert to f64
407        let v = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
408
409        let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
410
411        if self.n_components > n_features.min(n_samples) {
412            return Err(TransformError::InvalidInput(format!(
413                "n_components={} must be <= min(n_samples={}, n_features={})",
414                self.n_components, n_samples, n_features
415            )));
416        }
417
418        // Initialize W and H
419        let (mut w, mut h) = self.initialize_matrices(&v)?;
420
421        let mut prev_error = self.frobenius_loss(&v, &w, &h);
422        let mut n_iter = 0;
423
424        // Optimization loop
425        for iter in 0..self.max_iter {
426            // Update W and H
427            if self.solver == "mu" {
428                h = self.update_h(&v, &w, &h);
429                w = self.update_w(&v, &w, &h);
430            } else if self.solver == "cd" {
431                h = self.update_h_cd(&v, &w, &h);
432                w = self.update_w_cd(&v, &w, &h);
433            } else {
434                return Err(TransformError::InvalidInput(format!(
435                    "Unknown solver '{}'. Supported solvers: 'mu', 'cd'",
436                    self.solver
437                )));
438            }
439
440            // Compute error
441            let error = self.frobenius_loss(&v, &w, &h);
442
443            // Check convergence
444            if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
445                n_iter = iter + 1;
446                break;
447            }
448
449            prev_error = error;
450            n_iter = iter + 1;
451        }
452
453        self.components = Some(h);
454        self.coefficients = Some(w);
455        self.reconstruction_err = Some(prev_error);
456        self.n_iter = Some(n_iter);
457
458        Ok(())
459    }
460
461    /// Transform data using fitted NMF model
462    ///
463    /// # Arguments
464    /// * `x` - Input data matrix (must be non-negative)
465    ///
466    /// # Returns
467    /// * `Result<Array2<f64>>` - Transformed data (W matrix)
468    pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
469    where
470        S: Data,
471        S::Elem: Float + NumCast,
472    {
473        if self.components.is_none() {
474            return Err(TransformError::TransformationError(
475                "NMF model has not been fitted".to_string(),
476            ));
477        }
478
479        // Validate non-negativity before conversion
480        for elem in x.iter() {
481            let val = num_traits::cast::<S::Elem, f64>(*elem).unwrap_or(0.0);
482            if val < 0.0 {
483                return Err(TransformError::InvalidInput(
484                    "NMF requires non-negative input data".to_string(),
485                ));
486            }
487        }
488
489        // Convert to f64
490        let v = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
491
492        let h = self.components.as_ref().unwrap();
493        let n_samples = v.shape()[0];
494
495        // Initialize W randomly
496        let mut rng = rand::rng();
497
498        let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
499        let mut w = Array2::zeros((n_samples, self.n_components));
500
501        for i in 0..n_samples {
502            for j in 0..self.n_components {
503                w[[i, j]] = rng.random::<f64>() * scale;
504            }
505        }
506
507        // Update W while keeping H fixed
508        for _ in 0..self.max_iter {
509            w = self.update_w(&v, &w, h);
510        }
511
512        Ok(w)
513    }
514
515    /// Fit and transform in one step
516    ///
517    /// # Arguments
518    /// * `x` - Input data matrix (must be non-negative)
519    ///
520    /// # Returns
521    /// * `Result<Array2<f64>>` - Transformed data (W matrix)
522    pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
523    where
524        S: Data,
525        S::Elem: Float + NumCast,
526    {
527        self.fit(x)?;
528        Ok(self.coefficients.as_ref().unwrap().clone())
529    }
530
531    /// Get the components (H matrix)
532    pub fn components(&self) -> Option<&Array2<f64>> {
533        self.components.as_ref()
534    }
535
536    /// Get the coefficients (W matrix)
537    pub fn coefficients(&self) -> Option<&Array2<f64>> {
538        self.coefficients.as_ref()
539    }
540
541    /// Get reconstruction error
542    pub fn reconstruction_error(&self) -> Option<f64> {
543        self.reconstruction_err
544    }
545
546    /// Get number of iterations run
547    pub fn n_iterations(&self) -> Option<usize> {
548        self.n_iter
549    }
550
551    /// Inverse transform - reconstruct data from transformed representation
552    pub fn inverse_transform(&self, w: &Array2<f64>) -> Result<Array2<f64>> {
553        if self.components.is_none() {
554            return Err(TransformError::TransformationError(
555                "NMF model has not been fitted".to_string(),
556            ));
557        }
558
559        let h = self.components.as_ref().unwrap();
560        Ok(w.dot(h))
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// 6×4 rank-one fixture whose row `r` equals `(r + 1) * [1, 2, 3, 4]`.
    fn rank_one_matrix() -> Array2<f64> {
        Array2::from_shape_fn((6, 4), |(row, col)| ((row + 1) * (col + 1)) as f64)
    }

    /// Asserts that every entry of `matrix` is `>= 0.0`.
    fn assert_non_negative(matrix: &Array2<f64>) {
        assert!(matrix.iter().all(|&value| value >= 0.0));
    }

    #[test]
    fn test_nmf_basic() {
        let data = rank_one_matrix();
        let mut model = NMF::new(2).with_max_iter(100).with_random_state(42);

        let coeffs = model.fit_transform(&data).unwrap();
        assert_eq!(coeffs.dim(), (6, 2));
        assert_non_negative(&coeffs);

        let comps = model.components().unwrap();
        assert_eq!(comps.dim(), (2, 4));
        assert_non_negative(comps);

        let rebuilt = model.inverse_transform(&coeffs).unwrap();
        assert_eq!(rebuilt.dim(), data.dim());
    }

    #[test]
    fn test_nmf_regularization() {
        // The offset keeps every entry strictly positive.
        let data = Array2::<f64>::eye(10) + 0.1;
        let mut model = NMF::new(3).with_regularization(0.1, 0.5).with_max_iter(50);

        let coeffs = model
            .fit_transform(&data)
            .expect("regularized fit should succeed");
        assert_eq!(coeffs.dim(), (10, 3));
    }

    #[test]
    fn test_nmf_negative_input() {
        let data =
            Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, -1.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .unwrap();

        let err = NMF::new(2).fit(&data).unwrap_err();
        assert!(err
            .to_string()
            .contains("NMF requires non-negative input data"));
    }

    #[test]
    fn test_nmf_coordinate_descent() {
        let data = rank_one_matrix();

        let mut cd_model = NMF::new(2)
            .with_solver("cd")
            .with_max_iter(100)
            .with_random_state(42);
        let cd_coeffs = cd_model.fit_transform(&data).unwrap();
        assert_eq!(cd_coeffs.dim(), (6, 2));
        assert_non_negative(&cd_coeffs);

        let cd_comps = cd_model.components().unwrap();
        assert_eq!(cd_comps.dim(), (2, 4));
        assert_non_negative(cd_comps);

        let rebuilt = cd_model.inverse_transform(&cd_coeffs).unwrap();
        assert_eq!(rebuilt.dim(), data.dim());

        // The multiplicative-update solver must also produce a valid decomposition.
        let mut mu_model = NMF::new(2)
            .with_solver("mu")
            .with_max_iter(100)
            .with_random_state(42);
        let _mu_coeffs = mu_model.fit_transform(&data).unwrap();

        assert!(cd_model.reconstruction_error().unwrap() >= 0.0);
        assert!(mu_model.reconstruction_error().unwrap() >= 0.0);
    }

    #[test]
    fn test_nmf_invalid_solver() {
        let data = Array2::<f64>::eye(3) + 0.1;

        let err = NMF::new(2).with_solver("invalid").fit(&data).unwrap_err();
        assert!(err.to_string().contains("Unknown solver"));
    }
}