use ndarray::{Array2, ArrayBase, Data, Ix2};
use num_traits::{Float, NumCast};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

use crate::error::{Result, TransformError};

12#[derive(Debug, Clone)]
20pub struct NMF {
21 n_components: usize,
23 init: String,
25 solver: String,
27 beta_loss: f64,
29 max_iter: usize,
31 tol: f64,
33 random_state: Option<u64>,
35 alpha: f64,
37 l1_ratio: f64,
39 components: Option<Array2<f64>>,
41 coefficients: Option<Array2<f64>>,
43 reconstruction_err: Option<f64>,
45 n_iter: Option<usize>,
47}
48
49impl NMF {
50 pub fn new(ncomponents: usize) -> Self {
55 NMF {
56 n_components: ncomponents,
57 init: "random".to_string(),
58 solver: "mu".to_string(),
59 beta_loss: 2.0, max_iter: 200,
61 tol: 1e-4,
62 random_state: None,
63 alpha: 0.0,
64 l1_ratio: 0.0,
65 components: None,
66 coefficients: None,
67 reconstruction_err: None,
68 n_iter: None,
69 }
70 }
71
72 pub fn with_init(mut self, init: &str) -> Self {
74 self.init = init.to_string();
75 self
76 }
77
78 pub fn with_solver(mut self, solver: &str) -> Self {
80 self.solver = solver.to_string();
81 self
82 }
83
84 pub fn with_beta_loss(mut self, beta: f64) -> Self {
86 self.beta_loss = beta;
87 self
88 }
89
90 pub fn with_max_iter(mut self, maxiter: usize) -> Self {
92 self.max_iter = maxiter;
93 self
94 }
95
96 pub fn with_tolerance(mut self, tol: f64) -> Self {
98 self.tol = tol;
99 self
100 }
101
102 pub fn with_random_state(mut self, seed: u64) -> Self {
104 self.random_state = Some(seed);
105 self
106 }
107
108 pub fn with_regularization(mut self, alpha: f64, l1ratio: f64) -> Self {
110 self.alpha = alpha;
111 self.l1_ratio = l1ratio;
112 self
113 }
114
115 fn random_initialization(&self, v: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
117 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
118 let mut rng = rand::rng();
119
120 let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
121
122 let mut w = Array2::zeros((n_samples, self.n_components));
123 let mut h = Array2::zeros((self.n_components, n_features));
124
125 for i in 0..n_samples {
126 for j in 0..self.n_components {
127 w[[i, j]] = rng.random::<f64>() * scale;
128 }
129 }
130
131 for i in 0..self.n_components {
132 for j in 0..n_features {
133 h[[i, j]] = rng.random::<f64>() * scale;
134 }
135 }
136
137 (w, h)
138 }
139
140 fn nndsvd_initialization(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
142 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
143
144 let (u, s, vt) = match scirs2_linalg::svd::<f64>(&v.view(), true, None) {
146 Ok(result) => result,
147 Err(e) => return Err(TransformError::LinalgError(e)),
148 };
149
150 let mut w = Array2::zeros((n_samples, self.n_components));
151 let mut h = Array2::zeros((self.n_components, n_features));
152
153 for j in 0..self.n_components {
155 let x = u.column(j);
156 let y = vt.row(j);
157
158 let x_pos = x.mapv(|v| v.max(0.0));
160 let x_neg = x.mapv(|v| (-v).max(0.0));
161 let y_pos = y.mapv(|v| v.max(0.0));
162 let y_neg = y.mapv(|v| (-v).max(0.0));
163
164 let x_pos_norm = x_pos.dot(&x_pos).sqrt();
165 let x_neg_norm = x_neg.dot(&x_neg).sqrt();
166 let y_pos_norm = y_pos.dot(&y_pos).sqrt();
167 let y_neg_norm = y_neg.dot(&y_neg).sqrt();
168
169 let m_pos = x_pos_norm * y_pos_norm;
170 let m_neg = x_neg_norm * y_neg_norm;
171
172 if m_pos > m_neg {
173 for i in 0..n_samples {
174 w[[i, j]] = (s[j].sqrt() * x_pos[i] / x_pos_norm).max(0.0);
175 }
176 for i in 0..n_features {
177 h[[j, i]] = (s[j].sqrt() * y_pos[i] / y_pos_norm).max(0.0);
178 }
179 } else {
180 for i in 0..n_samples {
181 w[[i, j]] = (s[j].sqrt() * x_neg[i] / x_neg_norm).max(0.0);
182 }
183 for i in 0..n_features {
184 h[[j, i]] = (s[j].sqrt() * y_neg[i] / y_neg_norm).max(0.0);
185 }
186 }
187 }
188
189 Ok((w, h))
190 }
191
192 fn initialize_matrices(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
194 match self.init.as_str() {
195 "random" => Ok(self.random_initialization(v)),
196 "nndsvd" => self.nndsvd_initialization(v),
197 _ => Ok(self.random_initialization(v)),
198 }
199 }
200
201 fn frobenius_loss(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
203 let wh = w.dot(h);
204 let diff = v - &wh;
205 diff.mapv(|x| x * x).sum().sqrt()
206 }
207
208 fn update_w(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
210 let eps = 1e-10;
211 let wh = w.dot(h);
212
213 let numerator = v.dot(&h.t());
215
216 let mut denominator = wh.dot(&h.t());
218
219 if self.alpha > 0.0 && self.l1_ratio < 1.0 {
221 let l2_reg = self.alpha * (1.0 - self.l1_ratio);
222 denominator = &denominator + &(w * l2_reg);
223 }
224
225 if self.alpha > 0.0 && self.l1_ratio > 0.0 {
227 let l1_reg = self.alpha * self.l1_ratio;
228 denominator = denominator.mapv(|x| x + l1_reg);
229 }
230
231 let mut w_new = w * &(numerator / (denominator + eps));
233
234 w_new.mapv_inplace(|x| x.max(eps));
236
237 w_new
238 }
239
240 fn update_h(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
242 let eps = 1e-10;
243 let wh = w.dot(h);
244
245 let numerator = w.t().dot(v);
247
248 let mut denominator = w.t().dot(&wh);
250
251 if self.alpha > 0.0 && self.l1_ratio < 1.0 {
253 let l2_reg = self.alpha * (1.0 - self.l1_ratio);
254 denominator = &denominator + &(h * l2_reg);
255 }
256
257 if self.alpha > 0.0 && self.l1_ratio > 0.0 {
259 let l1_reg = self.alpha * self.l1_ratio;
260 denominator = denominator.mapv(|x| x + l1_reg);
261 }
262
263 let mut h_new = h * &(numerator / (denominator + eps));
265
266 h_new.mapv_inplace(|x| x.max(eps));
268
269 h_new
270 }
271
272 fn update_w_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
274 let eps = 1e-10;
275 let (n_samples, n_components) = w.dim();
276 let mut w_new = w.clone();
277
278 let hht = h.dot(&h.t());
280
281 for i in 0..n_samples {
282 for j in 0..n_components {
283 let mut numerator = 0.0;
285 let mut denominator = hht[[j, j]];
286
287 for k in 0..h.ncols() {
289 numerator += v[[i, k]] * h[[j, k]];
290 }
291
292 for k in 0..n_components {
294 if k != j {
295 numerator -= w_new[[i, k]] * hht[[k, j]];
296 }
297 }
298
299 if self.alpha > 0.0 {
301 if self.l1_ratio > 0.0 {
302 let l1_penalty = self.alpha * self.l1_ratio;
304 numerator -= l1_penalty;
305 }
306 if self.l1_ratio < 1.0 {
307 let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
309 denominator += l2_penalty;
310 numerator -= l2_penalty * w_new[[i, j]];
311 }
312 }
313
314 let new_val = if denominator > eps {
316 (numerator / denominator).max(eps)
317 } else {
318 eps
319 };
320
321 w_new[[i, j]] = new_val;
322 }
323 }
324
325 w_new
326 }
327
328 fn update_h_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
330 let eps = 1e-10;
331 let (n_components, n_features) = h.dim();
332 let mut h_new = h.clone();
333
334 let wtw = w.t().dot(w);
336
337 for i in 0..n_components {
338 for j in 0..n_features {
339 let mut numerator = 0.0;
341 let mut denominator = wtw[[i, i]];
342
343 for k in 0..w.nrows() {
345 numerator += w[[k, i]] * v[[k, j]];
346 }
347
348 for k in 0..n_components {
350 if k != i {
351 numerator -= wtw[[i, k]] * h_new[[k, j]];
352 }
353 }
354
355 if self.alpha > 0.0 {
357 if self.l1_ratio > 0.0 {
358 let l1_penalty = self.alpha * self.l1_ratio;
360 numerator -= l1_penalty;
361 }
362 if self.l1_ratio < 1.0 {
363 let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
365 denominator += l2_penalty;
366 numerator -= l2_penalty * h_new[[i, j]];
367 }
368 }
369
370 let new_val = if denominator > eps {
372 (numerator / denominator).max(eps)
373 } else {
374 eps
375 };
376
377 h_new[[i, j]] = new_val;
378 }
379 }
380
381 h_new
382 }
383
384 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
392 where
393 S: Data,
394 S::Elem: Float + NumCast,
395 {
396 for elem in x.iter() {
398 let val = num_traits::cast::<S::Elem, f64>(*elem).unwrap_or(0.0);
399 if val < 0.0 {
400 return Err(TransformError::InvalidInput(
401 "NMF requires non-negative input data".to_string(),
402 ));
403 }
404 }
405
406 let v = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
408
409 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
410
411 if self.n_components > n_features.min(n_samples) {
412 return Err(TransformError::InvalidInput(format!(
413 "n_components={} must be <= min(n_samples={}, n_features={})",
414 self.n_components, n_samples, n_features
415 )));
416 }
417
418 let (mut w, mut h) = self.initialize_matrices(&v)?;
420
421 let mut prev_error = self.frobenius_loss(&v, &w, &h);
422 let mut n_iter = 0;
423
424 for iter in 0..self.max_iter {
426 if self.solver == "mu" {
428 h = self.update_h(&v, &w, &h);
429 w = self.update_w(&v, &w, &h);
430 } else if self.solver == "cd" {
431 h = self.update_h_cd(&v, &w, &h);
432 w = self.update_w_cd(&v, &w, &h);
433 } else {
434 return Err(TransformError::InvalidInput(format!(
435 "Unknown solver '{}'. Supported solvers: 'mu', 'cd'",
436 self.solver
437 )));
438 }
439
440 let error = self.frobenius_loss(&v, &w, &h);
442
443 if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
445 n_iter = iter + 1;
446 break;
447 }
448
449 prev_error = error;
450 n_iter = iter + 1;
451 }
452
453 self.components = Some(h);
454 self.coefficients = Some(w);
455 self.reconstruction_err = Some(prev_error);
456 self.n_iter = Some(n_iter);
457
458 Ok(())
459 }
460
461 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
469 where
470 S: Data,
471 S::Elem: Float + NumCast,
472 {
473 if self.components.is_none() {
474 return Err(TransformError::TransformationError(
475 "NMF model has not been fitted".to_string(),
476 ));
477 }
478
479 for elem in x.iter() {
481 let val = num_traits::cast::<S::Elem, f64>(*elem).unwrap_or(0.0);
482 if val < 0.0 {
483 return Err(TransformError::InvalidInput(
484 "NMF requires non-negative input data".to_string(),
485 ));
486 }
487 }
488
489 let v = x.mapv(|x| num_traits::cast::<S::Elem, f64>(x).unwrap_or(0.0));
491
492 let h = self.components.as_ref().unwrap();
493 let n_samples = v.shape()[0];
494
495 let mut rng = rand::rng();
497
498 let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
499 let mut w = Array2::zeros((n_samples, self.n_components));
500
501 for i in 0..n_samples {
502 for j in 0..self.n_components {
503 w[[i, j]] = rng.random::<f64>() * scale;
504 }
505 }
506
507 for _ in 0..self.max_iter {
509 w = self.update_w(&v, &w, h);
510 }
511
512 Ok(w)
513 }
514
515 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
523 where
524 S: Data,
525 S::Elem: Float + NumCast,
526 {
527 self.fit(x)?;
528 Ok(self.coefficients.as_ref().unwrap().clone())
529 }
530
531 pub fn components(&self) -> Option<&Array2<f64>> {
533 self.components.as_ref()
534 }
535
536 pub fn coefficients(&self) -> Option<&Array2<f64>> {
538 self.coefficients.as_ref()
539 }
540
541 pub fn reconstruction_error(&self) -> Option<f64> {
543 self.reconstruction_err
544 }
545
546 pub fn n_iterations(&self) -> Option<usize> {
548 self.n_iter
549 }
550
551 pub fn inverse_transform(&self, w: &Array2<f64>) -> Result<Array2<f64>> {
553 if self.components.is_none() {
554 return Err(TransformError::TransformationError(
555 "NMF model has not been fitted".to_string(),
556 ));
557 }
558
559 let h = self.components.as_ref().unwrap();
560 Ok(w.dot(h))
561 }
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array;

    /// Rank-one 6x4 fixture whose entry `(r, c)` is `(r + 1) * (c + 1)`.
    fn rank_one_matrix() -> Array2<f64> {
        Array2::from_shape_fn((6, 4), |(r, c)| ((r + 1) * (c + 1)) as f64)
    }

    /// Asserts that no entry of `a` is negative.
    fn assert_all_nonnegative(a: &Array2<f64>) {
        assert!(a.iter().all(|&value| value >= 0.0));
    }

    /// Checks the shapes and signs of a 2-component factorization of the fixture.
    fn check_rank_one_factorization(model: &NMF, coefficients: &Array2<f64>, data: &Array2<f64>) {
        assert_eq!(coefficients.shape(), &[6, 2]);
        assert_all_nonnegative(coefficients);

        let components = model.components().unwrap();
        assert_eq!(components.shape(), &[2, 4]);
        assert_all_nonnegative(components);

        let reconstructed = model.inverse_transform(coefficients).unwrap();
        assert_eq!(reconstructed.shape(), data.shape());
    }

    #[test]
    fn test_nmf_basic() {
        let data = rank_one_matrix();
        let mut model = NMF::new(2).with_max_iter(100).with_random_state(42);

        let coefficients = model.fit_transform(&data).unwrap();
        check_rank_one_factorization(&model, &coefficients, &data);
    }

    #[test]
    fn test_nmf_regularization() {
        let data = Array2::<f64>::eye(10) + 0.1;
        let mut model = NMF::new(3)
            .with_regularization(0.1, 0.5)
            .with_max_iter(50);

        let coefficients = model
            .fit_transform(&data)
            .expect("regularized fit should succeed");
        assert_eq!(coefficients.shape(), &[10, 3]);
    }

    #[test]
    fn test_nmf_negative_input() {
        let data = Array::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, -1.0, 5.0, 6.0, 7.0, 8.0, 9.0])
            .unwrap();

        let err = NMF::new(2)
            .fit(&data)
            .expect_err("negative entries must be rejected");
        assert!(err
            .to_string()
            .contains("NMF requires non-negative input data"));
    }

    #[test]
    fn test_nmf_coordinate_descent() {
        let data = rank_one_matrix();

        let mut cd_model = NMF::new(2)
            .with_solver("cd")
            .with_max_iter(100)
            .with_random_state(42);
        let cd_coefficients = cd_model.fit_transform(&data).unwrap();
        check_rank_one_factorization(&cd_model, &cd_coefficients, &data);

        let mut mu_model = NMF::new(2)
            .with_solver("mu")
            .with_max_iter(100)
            .with_random_state(42);
        let _ = mu_model.fit_transform(&data).unwrap();

        assert!(cd_model.reconstruction_error().unwrap() >= 0.0);
        assert!(mu_model.reconstruction_error().unwrap() >= 0.0);
    }

    #[test]
    fn test_nmf_invalid_solver() {
        let data = Array2::<f64>::eye(3) + 0.1;

        let err = NMF::new(2)
            .with_solver("invalid")
            .fit(&data)
            .expect_err("unknown solver must be rejected");
        assert!(err.to_string().contains("Unknown solver"));
    }
}