1use scirs2_core::ndarray::{Array2, ArrayBase, Data, Ix2};
8use scirs2_core::numeric::{Float, NumCast};
9use scirs2_core::random::Rng;
10
11use crate::error::{Result, TransformError};
12
13#[derive(Debug, Clone)]
19pub struct NMF {
20 n_components: usize,
22 init: String,
24 solver: String,
26 beta_loss: f64,
28 max_iter: usize,
30 tol: f64,
32 random_state: Option<u64>,
34 alpha: f64,
36 l1_ratio: f64,
38 components: Option<Array2<f64>>,
40 coefficients: Option<Array2<f64>>,
42 reconstruction_err: Option<f64>,
44 n_iter: Option<usize>,
46}
47
48impl NMF {
49 pub fn new(ncomponents: usize) -> Self {
54 NMF {
55 n_components: ncomponents,
56 init: "random".to_string(),
57 solver: "mu".to_string(),
58 beta_loss: 2.0, max_iter: 200,
60 tol: 1e-4,
61 random_state: None,
62 alpha: 0.0,
63 l1_ratio: 0.0,
64 components: None,
65 coefficients: None,
66 reconstruction_err: None,
67 n_iter: None,
68 }
69 }
70
71 pub fn with_init(mut self, init: &str) -> Self {
73 self.init = init.to_string();
74 self
75 }
76
77 pub fn with_solver(mut self, solver: &str) -> Self {
79 self.solver = solver.to_string();
80 self
81 }
82
83 pub fn with_beta_loss(mut self, beta: f64) -> Self {
85 self.beta_loss = beta;
86 self
87 }
88
89 pub fn with_max_iter(mut self, maxiter: usize) -> Self {
91 self.max_iter = maxiter;
92 self
93 }
94
95 pub fn with_tolerance(mut self, tol: f64) -> Self {
97 self.tol = tol;
98 self
99 }
100
101 pub fn with_random_state(mut self, seed: u64) -> Self {
103 self.random_state = Some(seed);
104 self
105 }
106
107 pub fn with_regularization(mut self, alpha: f64, l1ratio: f64) -> Self {
109 self.alpha = alpha;
110 self.l1_ratio = l1ratio;
111 self
112 }
113
114 fn random_initialization(&self, v: &Array2<f64>) -> (Array2<f64>, Array2<f64>) {
116 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
117 let mut rng = scirs2_core::random::rng();
118
119 let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
120
121 let mut w = Array2::zeros((n_samples, self.n_components));
122 let mut h = Array2::zeros((self.n_components, n_features));
123
124 for i in 0..n_samples {
125 for j in 0..self.n_components {
126 w[[i, j]] = rng.random::<f64>() * scale;
127 }
128 }
129
130 for i in 0..self.n_components {
131 for j in 0..n_features {
132 h[[i, j]] = rng.random::<f64>() * scale;
133 }
134 }
135
136 (w, h)
137 }
138
139 fn nndsvd_initialization(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
141 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
142
143 let (u, s, vt) = match scirs2_linalg::svd::<f64>(&v.view(), true, None) {
145 Ok(result) => result,
146 Err(e) => return Err(TransformError::LinalgError(e)),
147 };
148
149 let mut w = Array2::zeros((n_samples, self.n_components));
150 let mut h = Array2::zeros((self.n_components, n_features));
151
152 for j in 0..self.n_components {
154 let x = u.column(j);
155 let y = vt.row(j);
156
157 let x_pos = x.mapv(|v| v.max(0.0));
159 let x_neg = x.mapv(|v| (-v).max(0.0));
160 let y_pos = y.mapv(|v| v.max(0.0));
161 let y_neg = y.mapv(|v| (-v).max(0.0));
162
163 let x_pos_norm = x_pos.dot(&x_pos).sqrt();
164 let x_neg_norm = x_neg.dot(&x_neg).sqrt();
165 let y_pos_norm = y_pos.dot(&y_pos).sqrt();
166 let y_neg_norm = y_neg.dot(&y_neg).sqrt();
167
168 let m_pos = x_pos_norm * y_pos_norm;
169 let m_neg = x_neg_norm * y_neg_norm;
170
171 if m_pos > m_neg {
172 for i in 0..n_samples {
173 w[[i, j]] = (s[j].sqrt() * x_pos[i] / x_pos_norm).max(0.0);
174 }
175 for i in 0..n_features {
176 h[[j, i]] = (s[j].sqrt() * y_pos[i] / y_pos_norm).max(0.0);
177 }
178 } else {
179 for i in 0..n_samples {
180 w[[i, j]] = (s[j].sqrt() * x_neg[i] / x_neg_norm).max(0.0);
181 }
182 for i in 0..n_features {
183 h[[j, i]] = (s[j].sqrt() * y_neg[i] / y_neg_norm).max(0.0);
184 }
185 }
186 }
187
188 Ok((w, h))
189 }
190
191 fn initialize_matrices(&self, v: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
193 match self.init.as_str() {
194 "random" => Ok(self.random_initialization(v)),
195 "nndsvd" => self.nndsvd_initialization(v),
196 _ => Ok(self.random_initialization(v)),
197 }
198 }
199
200 fn frobenius_loss(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> f64 {
202 let wh = w.dot(h);
203 let diff = v - &wh;
204 diff.mapv(|x| x * x).sum().sqrt()
205 }
206
207 fn update_w(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
209 let eps = 1e-10;
210 let wh = w.dot(h);
211
212 let numerator = v.dot(&h.t());
214
215 let mut denominator = wh.dot(&h.t());
217
218 if self.alpha > 0.0 && self.l1_ratio < 1.0 {
220 let l2_reg = self.alpha * (1.0 - self.l1_ratio);
221 denominator = &denominator + &(w * l2_reg);
222 }
223
224 if self.alpha > 0.0 && self.l1_ratio > 0.0 {
226 let l1_reg = self.alpha * self.l1_ratio;
227 denominator = denominator.mapv(|x| x + l1_reg);
228 }
229
230 let mut w_new = w * &(numerator / (denominator + eps));
232
233 w_new.mapv_inplace(|x| x.max(eps));
235
236 w_new
237 }
238
239 fn update_h(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
241 let eps = 1e-10;
242 let wh = w.dot(h);
243
244 let numerator = w.t().dot(v);
246
247 let mut denominator = w.t().dot(&wh);
249
250 if self.alpha > 0.0 && self.l1_ratio < 1.0 {
252 let l2_reg = self.alpha * (1.0 - self.l1_ratio);
253 denominator = &denominator + &(h * l2_reg);
254 }
255
256 if self.alpha > 0.0 && self.l1_ratio > 0.0 {
258 let l1_reg = self.alpha * self.l1_ratio;
259 denominator = denominator.mapv(|x| x + l1_reg);
260 }
261
262 let mut h_new = h * &(numerator / (denominator + eps));
264
265 h_new.mapv_inplace(|x| x.max(eps));
267
268 h_new
269 }
270
271 fn update_w_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
273 let eps = 1e-10;
274 let (n_samples, n_components) = w.dim();
275 let mut w_new = w.clone();
276
277 let hht = h.dot(&h.t());
279
280 for i in 0..n_samples {
281 for j in 0..n_components {
282 let mut numerator = 0.0;
284 let mut denominator = hht[[j, j]];
285
286 for k in 0..h.ncols() {
288 numerator += v[[i, k]] * h[[j, k]];
289 }
290
291 for k in 0..n_components {
293 if k != j {
294 numerator -= w_new[[i, k]] * hht[[k, j]];
295 }
296 }
297
298 if self.alpha > 0.0 {
300 if self.l1_ratio > 0.0 {
301 let l1_penalty = self.alpha * self.l1_ratio;
303 numerator -= l1_penalty;
304 }
305 if self.l1_ratio < 1.0 {
306 let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
308 denominator += l2_penalty;
309 numerator -= l2_penalty * w_new[[i, j]];
310 }
311 }
312
313 let new_val = if denominator > eps {
315 (numerator / denominator).max(eps)
316 } else {
317 eps
318 };
319
320 w_new[[i, j]] = new_val;
321 }
322 }
323
324 w_new
325 }
326
327 fn update_h_cd(&self, v: &Array2<f64>, w: &Array2<f64>, h: &Array2<f64>) -> Array2<f64> {
329 let eps = 1e-10;
330 let (n_components, n_features) = h.dim();
331 let mut h_new = h.clone();
332
333 let wtw = w.t().dot(w);
335
336 for i in 0..n_components {
337 for j in 0..n_features {
338 let mut numerator = 0.0;
340 let mut denominator = wtw[[i, i]];
341
342 for k in 0..w.nrows() {
344 numerator += w[[k, i]] * v[[k, j]];
345 }
346
347 for k in 0..n_components {
349 if k != i {
350 numerator -= wtw[[i, k]] * h_new[[k, j]];
351 }
352 }
353
354 if self.alpha > 0.0 {
356 if self.l1_ratio > 0.0 {
357 let l1_penalty = self.alpha * self.l1_ratio;
359 numerator -= l1_penalty;
360 }
361 if self.l1_ratio < 1.0 {
362 let l2_penalty = self.alpha * (1.0 - self.l1_ratio);
364 denominator += l2_penalty;
365 numerator -= l2_penalty * h_new[[i, j]];
366 }
367 }
368
369 let new_val = if denominator > eps {
371 (numerator / denominator).max(eps)
372 } else {
373 eps
374 };
375
376 h_new[[i, j]] = new_val;
377 }
378 }
379
380 h_new
381 }
382
383 pub fn fit<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<()>
391 where
392 S: Data,
393 S::Elem: Float + NumCast,
394 {
395 for elem in x.iter() {
397 let val = NumCast::from(*elem).unwrap_or(0.0);
398 if val < 0.0 {
399 return Err(TransformError::InvalidInput(
400 "NMF requires non-negative input data".to_string(),
401 ));
402 }
403 }
404
405 let v = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
407
408 let (n_samples, n_features) = (v.shape()[0], v.shape()[1]);
409
410 if self.n_components > n_features.min(n_samples) {
411 return Err(TransformError::InvalidInput(format!(
412 "n_components={} must be <= min(n_samples={}, n_features={})",
413 self.n_components, n_samples, n_features
414 )));
415 }
416
417 let (mut w, mut h) = self.initialize_matrices(&v)?;
419
420 let mut prev_error = self.frobenius_loss(&v, &w, &h);
421 let mut n_iter = 0;
422
423 for iter in 0..self.max_iter {
425 if self.solver == "mu" {
427 h = self.update_h(&v, &w, &h);
428 w = self.update_w(&v, &w, &h);
429 } else if self.solver == "cd" {
430 h = self.update_h_cd(&v, &w, &h);
431 w = self.update_w_cd(&v, &w, &h);
432 } else {
433 return Err(TransformError::InvalidInput(format!(
434 "Unknown solver '{}'. Supported solvers: 'mu', 'cd'",
435 self.solver
436 )));
437 }
438
439 let error = self.frobenius_loss(&v, &w, &h);
441
442 if (prev_error - error).abs() / prev_error.max(1e-10) < self.tol {
444 n_iter = iter + 1;
445 break;
446 }
447
448 prev_error = error;
449 n_iter = iter + 1;
450 }
451
452 self.components = Some(h);
453 self.coefficients = Some(w);
454 self.reconstruction_err = Some(prev_error);
455 self.n_iter = Some(n_iter);
456
457 Ok(())
458 }
459
460 pub fn transform<S>(&self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
468 where
469 S: Data,
470 S::Elem: Float + NumCast,
471 {
472 if self.components.is_none() {
473 return Err(TransformError::TransformationError(
474 "NMF model has not been fitted".to_string(),
475 ));
476 }
477
478 for elem in x.iter() {
480 let val = NumCast::from(*elem).unwrap_or(0.0);
481 if val < 0.0 {
482 return Err(TransformError::InvalidInput(
483 "NMF requires non-negative input data".to_string(),
484 ));
485 }
486 }
487
488 let v = x.mapv(|x| NumCast::from(x).unwrap_or(0.0));
490
491 let h = self.components.as_ref().unwrap();
492 let n_samples = v.shape()[0];
493
494 let mut rng = scirs2_core::random::rng();
496
497 let scale = (v.mean().unwrap() / self.n_components as f64).sqrt();
498 let mut w = Array2::zeros((n_samples, self.n_components));
499
500 for i in 0..n_samples {
501 for j in 0..self.n_components {
502 w[[i, j]] = rng.random::<f64>() * scale;
503 }
504 }
505
506 for _ in 0..self.max_iter {
508 w = self.update_w(&v, &w, h);
509 }
510
511 Ok(w)
512 }
513
514 pub fn fit_transform<S>(&mut self, x: &ArrayBase<S, Ix2>) -> Result<Array2<f64>>
522 where
523 S: Data,
524 S::Elem: Float + NumCast,
525 {
526 self.fit(x)?;
527 Ok(self.coefficients.as_ref().unwrap().clone())
528 }
529
530 pub fn components(&self) -> Option<&Array2<f64>> {
532 self.components.as_ref()
533 }
534
535 pub fn coefficients(&self) -> Option<&Array2<f64>> {
537 self.coefficients.as_ref()
538 }
539
540 pub fn reconstruction_error(&self) -> Option<f64> {
542 self.reconstruction_err
543 }
544
545 pub fn n_iterations(&self) -> Option<usize> {
547 self.n_iter
548 }
549
550 pub fn inverse_transform(&self, w: &Array2<f64>) -> Result<Array2<f64>> {
552 if self.components.is_none() {
553 return Err(TransformError::TransformationError(
554 "NMF model has not been fitted".to_string(),
555 ));
556 }
557
558 let h = self.components.as_ref().unwrap();
559 Ok(w.dot(h))
560 }
561}
562
563#[cfg(test)]
564mod tests {
565 use super::*;
566 use scirs2_core::ndarray::Array;
567
/// Fits a rank-one 6x4 matrix with the default solver and checks the shapes
/// and non-negativity of both factors, plus the shape of the reconstruction.
#[test]
fn test_nmf_basic() {
    // Row `r` is `(r + 1) * [1, 2, 3, 4]`.
    let x = Array::from_shape_fn((6, 4), |(row, col)| ((row + 1) * (col + 1)) as f64);

    let mut model = NMF::new(2).with_max_iter(100).with_random_state(42);
    let w = model.fit_transform(&x).unwrap();

    assert_eq!(w.shape(), &[6, 2]);
    assert!(w.iter().all(|&val| val >= 0.0));

    let h = model.components().unwrap();
    assert_eq!(h.shape(), &[2, 4]);
    assert!(h.iter().all(|&val| val >= 0.0));

    let rebuilt = model.inverse_transform(&w).unwrap();
    assert_eq!(rebuilt.shape(), x.shape());
}
604
/// Fits with mixed L1/L2 regularization and checks that fitting succeeds
/// and yields coefficients of the expected shape.
#[test]
fn test_nmf_regularization() {
    // Identity plus a small offset keeps every entry strictly positive.
    let x = Array2::<f64>::eye(10) + 0.1;
    let mut model = NMF::new(3)
        .with_regularization(0.1, 0.5)
        .with_max_iter(50);

    let w = model
        .fit_transform(&x)
        .expect("regularized NMF should fit");

    assert_eq!(w.dim(), (10, 3));
}
617
/// Fitting data that contains a negative entry must be rejected with the
/// non-negativity error message.
#[test]
fn test_nmf_negative_input() {
    let values = vec![1.0, 2.0, 3.0, -1.0, 5.0, 6.0, 7.0, 8.0, 9.0];
    let x = Array::from_shape_vec((3, 3), values).unwrap();

    let mut model = NMF::new(2);

    match model.fit(&x) {
        Err(err) => assert!(err
            .to_string()
            .contains("NMF requires non-negative input data")),
        Ok(()) => panic!("fit accepted negative input"),
    }
}
633
/// Fits the same rank-one matrix with the coordinate-descent solver, checks
/// shapes and non-negativity, then confirms both solvers report a
/// non-negative reconstruction error.
#[test]
fn test_nmf_coordinate_descent() {
    // Row `r` is `(r + 1) * [1, 2, 3, 4]`.
    let x = Array::from_shape_fn((6, 4), |(row, col)| ((row + 1) * (col + 1)) as f64);

    let mut cd_model = NMF::new(2)
        .with_solver("cd")
        .with_max_iter(100)
        .with_random_state(42);
    let w_cd = cd_model.fit_transform(&x).unwrap();

    assert_eq!(w_cd.shape(), &[6, 2]);
    assert!(w_cd.iter().all(|&val| val >= 0.0));

    let h_cd = cd_model.components().unwrap();
    assert_eq!(h_cd.shape(), &[2, 4]);
    assert!(h_cd.iter().all(|&val| val >= 0.0));

    let rebuilt = cd_model.inverse_transform(&w_cd).unwrap();
    assert_eq!(rebuilt.shape(), x.shape());

    // Fit the multiplicative-update solver on the same data for comparison.
    let mut mu_model = NMF::new(2)
        .with_solver("mu")
        .with_max_iter(100)
        .with_random_state(42);
    let _w_mu = mu_model.fit_transform(&x).unwrap();

    let cd_error = cd_model.reconstruction_error().unwrap();
    let mu_error = mu_model.reconstruction_error().unwrap();
    assert!(cd_error >= 0.0);
    assert!(mu_error >= 0.0);
}
685
/// An unrecognised solver name must make `fit` fail with an
/// "Unknown solver" message.
#[test]
fn test_nmf_invalid_solver() {
    let x = Array2::<f64>::eye(3) + 0.1;
    let mut model = NMF::new(2).with_solver("invalid");

    match model.fit(&x) {
        Err(err) => assert!(err.to_string().contains("Unknown solver")),
        Ok(()) => panic!("fit accepted an unknown solver"),
    }
}
695}