1use crate::error::{Result, TextError};
44use scirs2_core::ndarray::{Array1, Array2};
45use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
46use std::collections::{HashMap, HashSet};
47
/// Hyper-parameters for the collapsed Gibbs sampler in `GibbsLda`.
#[derive(Debug, Clone)]
pub struct GibbsLdaConfig {
    /// Number of latent topics. `GibbsLda::fit` rejects `0`.
    pub n_topics: usize,
    /// Smoothing constant added to every document-topic count.
    pub alpha: f64,
    /// Smoothing constant added to every topic-word count.
    pub beta: f64,
    /// Number of full sampling sweeps over the corpus.
    pub n_iterations: usize,
    /// Requested burn-in sweeps.
    /// NOTE(review): `GibbsLda::fit` in this file never reads this field.
    pub burn_in: usize,
    /// Seed for the sampler's RNG. With `None`, `GibbsLda::fit` falls back
    /// to the fixed seed `42`, so runs are deterministic either way.
    pub seed: Option<u64>,
}

impl Default for GibbsLdaConfig {
    /// Ten topics, `alpha = 0.1`, `beta = 0.01`, 500 sweeps, 50 burn-in
    /// sweeps and no explicit seed.
    fn default() -> Self {
        GibbsLdaConfig {
            seed: None,
            burn_in: 50,
            n_iterations: 500,
            beta: 0.01,
            alpha: 0.1,
            n_topics: 10,
        }
    }
}
81
/// Latent Dirichlet Allocation model trained by collapsed Gibbs sampling.
///
/// All state besides `config` is produced by `fit`; the accessors refuse to
/// answer until `fitted` is `true`.
#[derive(Debug)]
pub struct GibbsLda {
    /// Hyper-parameters supplied at construction time.
    config: GibbsLdaConfig,
    /// Word -> vocabulary index; indices follow the sorted order of the words.
    vocab: HashMap<String, usize>,
    /// Vocabulary index -> word (the sorted word list).
    rev_vocab: Vec<String>,
    /// Topic currently assigned to each token: `[document][token position]`.
    topic_assignments: Vec<Vec<usize>>,
    /// Number of tokens in document `d` assigned to topic `k`: `[d][k]`.
    n_dk: Vec<Vec<usize>>,
    /// Number of times word `w` is assigned to topic `k`: `[k][w]`.
    n_kw: Vec<Vec<usize>>,
    /// Total number of tokens assigned to each topic.
    n_k: Vec<usize>,
    /// Token count of each document.
    doc_lengths: Vec<usize>,
    /// Each document re-encoded as vocabulary indices.
    doc_words: Vec<Vec<usize>>,
    /// Set to `true` once `fit` has completed successfully.
    fitted: bool,
}
108
109impl GibbsLda {
110 pub fn new(config: GibbsLdaConfig) -> Self {
112 Self {
113 config,
114 vocab: HashMap::new(),
115 rev_vocab: Vec::new(),
116 topic_assignments: Vec::new(),
117 n_dk: Vec::new(),
118 n_kw: Vec::new(),
119 n_k: Vec::new(),
120 doc_lengths: Vec::new(),
121 doc_words: Vec::new(),
122 fitted: false,
123 }
124 }
125
126 pub fn fit(&mut self, documents: &[Vec<&str>]) -> Result<()> {
128 if documents.is_empty() {
129 return Err(TextError::InvalidInput(
130 "Cannot fit LDA on empty corpus".to_string(),
131 ));
132 }
133
134 let n_topics = self.config.n_topics;
135 if n_topics == 0 {
136 return Err(TextError::InvalidInput(
137 "Number of topics must be > 0".to_string(),
138 ));
139 }
140
141 self.vocab.clear();
143 self.rev_vocab.clear();
144 let mut word_set: HashSet<String> = HashSet::new();
145 for doc in documents {
146 for &word in doc {
147 word_set.insert(word.to_string());
148 }
149 }
150 let mut sorted_words: Vec<String> = word_set.into_iter().collect();
151 sorted_words.sort();
152 for (idx, word) in sorted_words.iter().enumerate() {
153 self.vocab.insert(word.clone(), idx);
154 }
155 self.rev_vocab = sorted_words;
156 let n_vocab = self.rev_vocab.len();
157
158 if n_vocab == 0 {
159 return Err(TextError::InvalidInput(
160 "Empty vocabulary after tokenization".to_string(),
161 ));
162 }
163
164 let n_docs = documents.len();
165
166 self.doc_words = documents
168 .iter()
169 .map(|doc| {
170 doc.iter()
171 .filter_map(|w| self.vocab.get(*w).copied())
172 .collect()
173 })
174 .collect();
175
176 self.doc_lengths = self.doc_words.iter().map(|d| d.len()).collect();
177
178 self.n_dk = vec![vec![0; n_topics]; n_docs];
180 self.n_kw = vec![vec![0; n_vocab]; n_topics];
181 self.n_k = vec![0; n_topics];
182 self.topic_assignments = Vec::with_capacity(n_docs);
183
184 let mut rng = match self.config.seed {
186 Some(seed) => StdRng::seed_from_u64(seed),
187 None => StdRng::seed_from_u64(42),
188 };
189
190 for d in 0..n_docs {
191 let mut doc_topics = Vec::with_capacity(self.doc_words[d].len());
192 for &w in &self.doc_words[d] {
193 let k = (rng.random::<f64>() * n_topics as f64) as usize % n_topics;
194 doc_topics.push(k);
195 self.n_dk[d][k] += 1;
196 self.n_kw[k][w] += 1;
197 self.n_k[k] += 1;
198 }
199 self.topic_assignments.push(doc_topics);
200 }
201
202 let alpha = self.config.alpha;
204 let beta = self.config.beta;
205 let beta_sum = beta * n_vocab as f64;
206
207 for _iter in 0..self.config.n_iterations {
208 for d in 0..n_docs {
209 let n_words_d = self.doc_words[d].len();
210 for i in 0..n_words_d {
211 let w = self.doc_words[d][i];
212 let old_k = self.topic_assignments[d][i];
213
214 self.n_dk[d][old_k] -= 1;
216 self.n_kw[old_k][w] -= 1;
217 self.n_k[old_k] -= 1;
218
219 let mut probs = vec![0.0f64; n_topics];
221 for k in 0..n_topics {
222 probs[k] = (self.n_dk[d][k] as f64 + alpha)
223 * (self.n_kw[k][w] as f64 + beta)
224 / (self.n_k[k] as f64 + beta_sum);
225 }
226
227 let total: f64 = probs.iter().sum();
229 if total < 1e-15 {
230 let new_k = (rng.random::<f64>() * n_topics as f64) as usize % n_topics;
232 self.topic_assignments[d][i] = new_k;
233 self.n_dk[d][new_k] += 1;
234 self.n_kw[new_k][w] += 1;
235 self.n_k[new_k] += 1;
236 continue;
237 }
238
239 let threshold = rng.random::<f64>() * total;
240 let mut cumsum = 0.0;
241 let mut new_k = n_topics - 1;
242 for k in 0..n_topics {
243 cumsum += probs[k];
244 if cumsum >= threshold {
245 new_k = k;
246 break;
247 }
248 }
249
250 self.topic_assignments[d][i] = new_k;
252 self.n_dk[d][new_k] += 1;
253 self.n_kw[new_k][w] += 1;
254 self.n_k[new_k] += 1;
255 }
256 }
257 }
258
259 self.fitted = true;
260 Ok(())
261 }
262
263 pub fn topic_word_distribution(&self, topic: usize) -> Result<Array1<f64>> {
265 if !self.fitted {
266 return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
267 }
268 if topic >= self.config.n_topics {
269 return Err(TextError::InvalidInput(format!(
270 "Topic {} out of range ({})",
271 topic, self.config.n_topics
272 )));
273 }
274
275 let n_vocab = self.rev_vocab.len();
276 let beta = self.config.beta;
277 let beta_sum = beta * n_vocab as f64;
278 let total = self.n_k[topic] as f64 + beta_sum;
279
280 let mut dist = Array1::<f64>::zeros(n_vocab);
281 for w in 0..n_vocab {
282 dist[w] = (self.n_kw[topic][w] as f64 + beta) / total;
283 }
284 Ok(dist)
285 }
286
287 pub fn doc_topic_distribution(&self, doc: usize) -> Result<Array1<f64>> {
289 if !self.fitted {
290 return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
291 }
292 if doc >= self.n_dk.len() {
293 return Err(TextError::InvalidInput(format!(
294 "Document {} out of range ({})",
295 doc,
296 self.n_dk.len()
297 )));
298 }
299
300 let n_topics = self.config.n_topics;
301 let alpha = self.config.alpha;
302 let total = self.doc_lengths[doc] as f64 + alpha * n_topics as f64;
303
304 let mut dist = Array1::<f64>::zeros(n_topics);
305 if total < 1e-15 {
306 let uniform = 1.0 / n_topics as f64;
308 for k in 0..n_topics {
309 dist[k] = uniform;
310 }
311 } else {
312 for k in 0..n_topics {
313 dist[k] = (self.n_dk[doc][k] as f64 + alpha) / total;
314 }
315 }
316 Ok(dist)
317 }
318
319 pub fn doc_topic_matrix(&self) -> Result<Array2<f64>> {
321 if !self.fitted {
322 return Err(TextError::ModelNotFitted("LDA not fitted".to_string()));
323 }
324
325 let n_docs = self.n_dk.len();
326 let n_topics = self.config.n_topics;
327 let mut matrix = Array2::<f64>::zeros((n_docs, n_topics));
328 for d in 0..n_docs {
329 let dist = self.doc_topic_distribution(d)?;
330 for k in 0..n_topics {
331 matrix[[d, k]] = dist[k];
332 }
333 }
334 Ok(matrix)
335 }
336
337 pub fn top_words(&self, n_words: usize) -> Vec<Vec<(String, f64)>> {
339 let n_topics = self.config.n_topics;
340 let mut result = Vec::with_capacity(n_topics);
341
342 for k in 0..n_topics {
343 let dist = match self.topic_word_distribution(k) {
344 Ok(d) => d,
345 Err(_) => {
346 result.push(Vec::new());
347 continue;
348 }
349 };
350
351 let mut word_probs: Vec<(usize, f64)> =
352 dist.iter().enumerate().map(|(i, &p)| (i, p)).collect();
353 word_probs.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
354
355 let top: Vec<(String, f64)> = word_probs
356 .iter()
357 .take(n_words.min(self.rev_vocab.len()))
358 .map(|(idx, prob)| (self.rev_vocab[*idx].clone(), *prob))
359 .collect();
360
361 result.push(top);
362 }
363 result
364 }
365
366 pub fn vocabulary(&self) -> &HashMap<String, usize> {
368 &self.vocab
369 }
370
371 pub fn n_topics(&self) -> usize {
373 self.config.n_topics
374 }
375
376 pub fn is_fitted(&self) -> bool {
378 self.fitted
379 }
380}
381
/// Settings for `NmfTopicModel`.
#[derive(Debug, Clone)]
pub struct NmfConfig {
    /// Number of topics, i.e. the inner dimension of the factorisation.
    pub n_topics: usize,
    /// Upper bound on the number of update iterations.
    pub max_iter: usize,
    /// Iteration stops once the reconstruction error changes by less than
    /// this amount between two consecutive iterations.
    pub tolerance: f64,
    /// Seed for the RNG that initialises the factors.
    pub seed: u64,
}

impl Default for NmfConfig {
    /// Ten topics, at most 200 iterations, tolerance `1e-4`, seed `42`.
    fn default() -> Self {
        NmfConfig {
            seed: 42,
            tolerance: 1e-4,
            max_iter: 200,
            n_topics: 10,
        }
    }
}
409
/// Topic model based on non-negative matrix factorisation `V ~ W * H`.
///
/// `w`, `h`, `vocab` and `error_history` are populated by `fit`.
#[derive(Debug)]
pub struct NmfTopicModel {
    /// Settings supplied at construction time.
    config: NmfConfig,
    /// Document-topic factor, shape `(n_docs, n_topics)`; `None` until fitted.
    w: Option<Array2<f64>>,
    /// Topic-term factor, shape `(n_topics, n_terms)`; `None` until fitted.
    h: Option<Array2<f64>>,
    /// Copy of the vocabulary slice handed to `fit`.
    vocab: Vec<String>,
    /// Reconstruction error recorded after every iteration of the last fit.
    error_history: Vec<f64>,
    /// Set to `true` once `fit` has completed successfully.
    fitted: bool,
}
431
432impl NmfTopicModel {
433 pub fn new(config: NmfConfig) -> Self {
435 Self {
436 config,
437 w: None,
438 h: None,
439 vocab: Vec::new(),
440 error_history: Vec::new(),
441 fitted: false,
442 }
443 }
444
445 pub fn fit(&mut self, matrix: &Array2<f64>, vocabulary: &[String]) -> Result<()> {
449 let (n_docs, n_terms) = matrix.dim();
450 let n_topics = self.config.n_topics;
451
452 if n_docs == 0 || n_terms == 0 {
453 return Err(TextError::InvalidInput(
454 "Cannot fit NMF on empty matrix".to_string(),
455 ));
456 }
457 if n_topics > n_docs || n_topics > n_terms {
458 return Err(TextError::InvalidInput(format!(
459 "n_topics ({}) must not exceed matrix dimensions ({}, {})",
460 n_topics, n_docs, n_terms
461 )));
462 }
463
464 self.vocab = vocabulary.to_vec();
465
466 let mut rng = StdRng::seed_from_u64(self.config.seed);
468 let mut w = Array2::<f64>::zeros((n_docs, n_topics));
469 let mut h = Array2::<f64>::zeros((n_topics, n_terms));
470
471 let eps = 1e-10;
472 for elem in w.iter_mut() {
473 *elem = rng.random::<f64>() * 0.1 + eps;
474 }
475 for elem in h.iter_mut() {
476 *elem = rng.random::<f64>() * 0.1 + eps;
477 }
478
479 self.error_history.clear();
480
481 for _iter in 0..self.config.max_iter {
483 let wt_v = mat_mul_ata_b(&w, matrix);
485 let wt_w = mat_mul_ata_b(&w, &w);
486 let wt_w_h = mat_mul_ab(&wt_w, &h);
487
488 for i in 0..n_topics {
489 for j in 0..n_terms {
490 let denom = wt_w_h[[i, j]] + eps;
491 h[[i, j]] *= wt_v[[i, j]] / denom;
492 if h[[i, j]] < eps {
493 h[[i, j]] = eps;
494 }
495 }
496 }
497
498 let v_ht = mat_mul_abt(matrix, &h);
500 let w_h = mat_mul_ab(&w, &h);
501 let w_h_ht = mat_mul_abt(&w_h, &h);
502
503 for i in 0..n_docs {
504 for j in 0..n_topics {
505 let denom = w_h_ht[[i, j]] + eps;
506 w[[i, j]] *= v_ht[[i, j]] / denom;
507 if w[[i, j]] < eps {
508 w[[i, j]] = eps;
509 }
510 }
511 }
512
513 let wh = mat_mul_ab(&w, &h);
515 let mut error = 0.0;
516 for i in 0..n_docs {
517 for j in 0..n_terms {
518 let diff = matrix[[i, j]] - wh[[i, j]];
519 error += diff * diff;
520 }
521 }
522 error = error.sqrt();
523 self.error_history.push(error);
524
525 if self.error_history.len() >= 2 {
527 let prev = self.error_history[self.error_history.len() - 2];
528 if (prev - error).abs() < self.config.tolerance {
529 break;
530 }
531 }
532 }
533
534 self.w = Some(w);
535 self.h = Some(h);
536 self.fitted = true;
537 Ok(())
538 }
539
540 pub fn doc_topic_matrix(&self) -> Result<&Array2<f64>> {
542 self.w
543 .as_ref()
544 .ok_or_else(|| TextError::ModelNotFitted("NMF not fitted".to_string()))
545 }
546
547 pub fn topic_term_matrix(&self) -> Result<&Array2<f64>> {
549 self.h
550 .as_ref()
551 .ok_or_else(|| TextError::ModelNotFitted("NMF not fitted".to_string()))
552 }
553
554 pub fn top_words(&self, n_words: usize) -> Result<Vec<Vec<(String, f64)>>> {
556 let h = self.topic_term_matrix()?;
557 let n_topics = h.nrows();
558 let mut result = Vec::with_capacity(n_topics);
559
560 for k in 0..n_topics {
561 let row = h.row(k);
562 let mut word_scores: Vec<(usize, f64)> =
563 row.iter().enumerate().map(|(i, &v)| (i, v)).collect();
564 word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
565
566 let top: Vec<(String, f64)> = word_scores
567 .iter()
568 .take(n_words.min(self.vocab.len()))
569 .filter_map(|(idx, score)| self.vocab.get(*idx).map(|w| (w.clone(), *score)))
570 .collect();
571 result.push(top);
572 }
573 Ok(result)
574 }
575
576 pub fn error_history(&self) -> &[f64] {
578 &self.error_history
579 }
580
581 pub fn is_fitted(&self) -> bool {
583 self.fitted
584 }
585}
586
/// Scores topics by how often their top words co-occur in documents.
#[derive(Debug, Clone)]
pub struct TopicCoherenceScorer {
    /// Configurable window size.
    /// NOTE(review): neither coherence method in this file reads it; both
    /// count co-occurrence at whole-document level.
    window_size: usize,
    /// Small constant added to document frequencies to avoid `ln(0)` and
    /// division by zero.
    epsilon: f64,
}

impl Default for TopicCoherenceScorer {
    /// Window size `10`, epsilon `1e-12`.
    fn default() -> Self {
        TopicCoherenceScorer {
            epsilon: 1e-12,
            window_size: 10,
        }
    }
}
610
611impl TopicCoherenceScorer {
612 pub fn new() -> Self {
614 Self::default()
615 }
616
617 pub fn with_window_size(mut self, size: usize) -> Self {
619 self.window_size = size;
620 self
621 }
622
623 pub fn cv_coherence(
627 &self,
628 topic_words: &[Vec<String>],
629 documents: &[Vec<String>],
630 ) -> Result<f64> {
631 if topic_words.is_empty() || documents.is_empty() {
632 return Err(TextError::InvalidInput(
633 "Topic words and documents must not be empty".to_string(),
634 ));
635 }
636
637 let n_docs = documents.len() as f64;
638
639 let doc_sets: Vec<HashSet<&String>> =
641 documents.iter().map(|doc| doc.iter().collect()).collect();
642
643 let mut topic_scores = Vec::with_capacity(topic_words.len());
644
645 for words in topic_words {
646 if words.len() < 2 {
647 topic_scores.push(0.0);
648 continue;
649 }
650
651 let mut npmi_sum = 0.0;
652 let mut pair_count = 0;
653
654 for i in 0..words.len() {
655 for j in (i + 1)..words.len() {
656 let wi = &words[i];
657 let wj = &words[j];
658
659 let df_i = doc_sets.iter().filter(|s| s.contains(wi)).count() as f64;
660 let df_j = doc_sets.iter().filter(|s| s.contains(wj)).count() as f64;
661 let df_ij = doc_sets
662 .iter()
663 .filter(|s| s.contains(wi) && s.contains(wj))
664 .count() as f64;
665
666 let p_i = (df_i + self.epsilon) / n_docs;
667 let p_j = (df_j + self.epsilon) / n_docs;
668 let p_ij = (df_ij + self.epsilon) / n_docs;
669
670 let pmi = (p_ij / (p_i * p_j)).ln();
672 let neg_log_p_ij = -(p_ij.ln());
673
674 let npmi = if neg_log_p_ij.abs() > self.epsilon {
675 pmi / neg_log_p_ij
676 } else {
677 0.0
678 };
679
680 npmi_sum += npmi;
681 pair_count += 1;
682 }
683 }
684
685 let score = if pair_count > 0 {
686 npmi_sum / pair_count as f64
687 } else {
688 0.0
689 };
690 topic_scores.push(score);
691 }
692
693 let avg = topic_scores.iter().sum::<f64>() / topic_scores.len() as f64;
694 Ok(avg)
695 }
696
697 pub fn umass_coherence(
701 &self,
702 topic_words: &[Vec<String>],
703 documents: &[Vec<String>],
704 ) -> Result<f64> {
705 if topic_words.is_empty() || documents.is_empty() {
706 return Err(TextError::InvalidInput(
707 "Topic words and documents must not be empty".to_string(),
708 ));
709 }
710
711 let doc_sets: Vec<HashSet<&String>> =
712 documents.iter().map(|doc| doc.iter().collect()).collect();
713
714 let mut topic_scores = Vec::with_capacity(topic_words.len());
715
716 for words in topic_words {
717 if words.len() < 2 {
718 topic_scores.push(0.0);
719 continue;
720 }
721
722 let mut score = 0.0;
723 let mut pair_count = 0;
724
725 for i in 1..words.len() {
726 for j in 0..i {
727 let wi = &words[i];
728 let wj = &words[j];
729
730 let df_j = doc_sets.iter().filter(|s| s.contains(wj)).count() as f64;
731 let df_ij = doc_sets
732 .iter()
733 .filter(|s| s.contains(wi) && s.contains(wj))
734 .count() as f64;
735
736 score += ((df_ij + self.epsilon) / (df_j + self.epsilon)).ln();
738 pair_count += 1;
739 }
740 }
741
742 let avg_score = if pair_count > 0 {
743 score / pair_count as f64
744 } else {
745 0.0
746 };
747 topic_scores.push(avg_score);
748 }
749
750 let avg = topic_scores.iter().sum::<f64>() / topic_scores.len() as f64;
751 Ok(avg)
752 }
753}
754
755pub fn select_n_topics(
776 documents: &[Vec<&str>],
777 min_topics: usize,
778 max_topics: usize,
779 n_iterations: usize,
780 seed: u64,
781) -> Result<(usize, Vec<(usize, f64)>)> {
782 if documents.is_empty() {
783 return Err(TextError::InvalidInput(
784 "Cannot select topics on empty corpus".to_string(),
785 ));
786 }
787 if min_topics == 0 || min_topics > max_topics {
788 return Err(TextError::InvalidInput(format!(
789 "Invalid topic range: {} to {}",
790 min_topics, max_topics
791 )));
792 }
793
794 let scorer = TopicCoherenceScorer::new();
795
796 let doc_strings: Vec<Vec<String>> = documents
798 .iter()
799 .map(|doc| doc.iter().map(|w| w.to_string()).collect())
800 .collect();
801
802 let mut scores: Vec<(usize, f64)> = Vec::new();
803 let mut best_k = min_topics;
804 let mut best_score = f64::NEG_INFINITY;
805
806 for k in min_topics..=max_topics {
807 let config = GibbsLdaConfig {
808 n_topics: k,
809 alpha: 50.0 / k as f64,
810 beta: 0.01,
811 n_iterations,
812 burn_in: n_iterations / 5,
813 seed: Some(seed),
814 };
815
816 let mut lda = GibbsLda::new(config);
817 lda.fit(documents)?;
818
819 let top_words = lda.top_words(10);
820 let topic_word_strs: Vec<Vec<String>> = top_words
821 .iter()
822 .map(|tw| tw.iter().map(|(w, _)| w.clone()).collect())
823 .collect();
824
825 let coherence = scorer.cv_coherence(&topic_word_strs, &doc_strings)?;
826 scores.push((k, coherence));
827
828 if coherence > best_score {
829 best_score = coherence;
830 best_k = k;
831 }
832 }
833
834 Ok((best_k, scores))
835}
836
837fn mat_mul_ab(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
843 let (ar, ac) = a.dim();
844 let (_br, bc) = b.dim();
845 let mut result = Array2::<f64>::zeros((ar, bc));
846 for i in 0..ar {
847 for k in 0..ac {
848 let a_ik = a[[i, k]];
849 if a_ik.abs() < 1e-15 {
850 continue;
851 }
852 for j in 0..bc {
853 result[[i, j]] += a_ik * b[[k, j]];
854 }
855 }
856 }
857 result
858}
859
860fn mat_mul_ata_b(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
862 let (ar, ac) = a.dim();
863 let (_br, bc) = b.dim();
864 let mut result = Array2::<f64>::zeros((ac, bc));
865 for k in 0..ar {
866 for i in 0..ac {
867 let a_ki = a[[k, i]];
868 if a_ki.abs() < 1e-15 {
869 continue;
870 }
871 for j in 0..bc {
872 result[[i, j]] += a_ki * b[[k, j]];
873 }
874 }
875 }
876 result
877}
878
879fn mat_mul_abt(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
881 let (ar, ac) = a.dim();
882 let (br, _bc) = b.dim();
883 let mut result = Array2::<f64>::zeros((ar, br));
884 for i in 0..ar {
885 for j in 0..br {
886 let mut sum = 0.0;
887 for k in 0..ac {
888 sum += a[[i, k]] * b[[j, k]];
889 }
890 result[[i, j]] = sum;
891 }
892 }
893 result
894}
895
#[cfg(test)]
mod tests {
    use super::*;

    /// Six short documents: three about ML/NLP, three about pets.
    fn sample_docs() -> Vec<Vec<&'static str>> {
        vec![
            vec!["machine", "learning", "algorithm", "data", "model"],
            vec!["deep", "learning", "neural", "network", "training"],
            vec!["natural", "language", "processing", "text", "word"],
            vec!["cat", "dog", "pet", "animal", "food"],
            vec!["pet", "care", "food", "animal", "home"],
            vec!["dog", "cat", "play", "park", "fun"],
        ]
    }

    /// Fits a Gibbs LDA (seed 42, other settings default) on `sample_docs`.
    fn fitted_lda(n_topics: usize, n_iterations: usize) -> GibbsLda {
        let docs = sample_docs();
        let mut lda = GibbsLda::new(GibbsLdaConfig {
            n_topics,
            n_iterations,
            seed: Some(42),
            ..Default::default()
        });
        lda.fit(&docs).expect("fit failed");
        lda
    }

    /// Converts string literals into owned strings.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    /// 4x5 count matrix with two obvious term blocks, plus its vocabulary.
    fn two_block_counts() -> (Array2<f64>, Vec<String>) {
        let matrix = Array2::from_shape_vec(
            (4, 5),
            vec![
                1.0, 2.0, 0.0, 0.0, 0.0, 2.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0, 0.0,
                0.0, 2.0, 1.0, 2.0,
            ],
        )
        .expect("matrix creation failed");
        (matrix, owned(&["ml", "deep", "cat", "dog", "pet"]))
    }

    /// 3x4 binary matrix with the generated vocabulary `w0..w3`.
    fn small_counts() -> (Array2<f64>, Vec<String>) {
        let matrix = Array2::from_shape_vec(
            (3, 4),
            vec![1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0],
        )
        .expect("matrix creation failed");
        let vocab = (0..4).map(|i| format!("w{}", i)).collect();
        (matrix, vocab)
    }

    /// Runs a two-topic NMF with the given iteration cap.
    fn fit_nmf(matrix: &Array2<f64>, vocab: &[String], max_iter: usize) -> Result<NmfTopicModel> {
        let mut nmf = NmfTopicModel::new(NmfConfig {
            n_topics: 2,
            max_iter,
            ..Default::default()
        });
        nmf.fit(matrix, vocab)?;
        Ok(nmf)
    }

    #[test]
    fn test_gibbs_lda_fit() {
        assert!(fitted_lda(2, 50).is_fitted());
    }

    #[test]
    fn test_gibbs_lda_top_words() {
        let topics = fitted_lda(2, 100).top_words(5);
        assert_eq!(topics.len(), 2);
        for topic in &topics {
            assert_eq!(topic.len(), 5);
            let prob_sum: f64 = topic.iter().map(|(_, p)| p).sum();
            assert!(prob_sum > 0.0);
        }
    }

    #[test]
    fn test_gibbs_lda_doc_topic_distribution() {
        let lda = fitted_lda(2, 50);
        let dist = lda.doc_topic_distribution(0).expect("dist failed");
        assert_eq!(dist.len(), 2);
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gibbs_lda_topic_word_distribution() {
        let lda = fitted_lda(2, 50);
        let dist = lda.topic_word_distribution(0).expect("dist failed");
        let sum: f64 = dist.iter().sum();
        assert!((sum - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_gibbs_lda_doc_topic_matrix() {
        let lda = fitted_lda(3, 50);
        let matrix = lda.doc_topic_matrix().expect("matrix failed");
        assert_eq!(matrix.dim(), (6, 3));

        // Every row is a probability distribution over topics.
        for i in 0..6 {
            let sum: f64 = matrix.row(i).iter().sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_gibbs_lda_empty_corpus() {
        let mut lda = GibbsLda::new(GibbsLdaConfig::default());
        assert!(lda.fit(&[]).is_err());
    }

    #[test]
    fn test_gibbs_lda_not_fitted() {
        let lda = GibbsLda::new(GibbsLdaConfig::default());
        assert!(lda.doc_topic_distribution(0).is_err());
        assert!(lda.topic_word_distribution(0).is_err());
    }

    #[test]
    fn test_nmf_fit() {
        let (matrix, vocab) = two_block_counts();
        let nmf = fit_nmf(&matrix, &vocab, 100).expect("nmf fit failed");
        assert!(nmf.is_fitted());
    }

    #[test]
    fn test_nmf_top_words() {
        let (matrix, vocab) = two_block_counts();
        let nmf = fit_nmf(&matrix, &vocab, 100).expect("nmf fit failed");

        let topics = nmf.top_words(3).expect("top_words failed");
        assert_eq!(topics.len(), 2);
        for topic in &topics {
            assert!(topic.len() <= 3);
        }
    }

    #[test]
    fn test_nmf_convergence() {
        let (matrix, vocab) = small_counts();
        let nmf = fit_nmf(&matrix, &vocab, 200).expect("nmf fit failed");

        let errors = nmf.error_history();
        assert!(!errors.is_empty());
        if errors.len() >= 2 {
            // The final error must not exceed the first one.
            assert!(
                errors.last().copied().unwrap_or(f64::MAX)
                    <= errors.first().copied().unwrap_or(0.0) + 1e-6
            );
        }
    }

    #[test]
    fn test_nmf_not_fitted() {
        let nmf = NmfTopicModel::new(NmfConfig::default());
        assert!(nmf.doc_topic_matrix().is_err());
        assert!(nmf.topic_term_matrix().is_err());
    }

    #[test]
    fn test_coherence_cv() {
        let topic_words = vec![
            owned(&["machine", "learning", "algorithm"]),
            owned(&["cat", "dog", "pet"]),
        ];
        let documents = vec![
            owned(&["machine", "learning", "algorithm"]),
            owned(&["deep", "learning", "neural"]),
            owned(&["cat", "dog", "pet"]),
            owned(&["cat", "play", "fun"]),
        ];

        let cv = TopicCoherenceScorer::new()
            .cv_coherence(&topic_words, &documents)
            .expect("cv failed");
        assert!(cv.is_finite());
    }

    #[test]
    fn test_coherence_umass() {
        let topic_words = vec![owned(&["machine", "learning"]), owned(&["cat", "dog"])];
        let documents = vec![owned(&["machine", "learning"]), owned(&["cat", "dog"])];

        let umass = TopicCoherenceScorer::new()
            .umass_coherence(&topic_words, &documents)
            .expect("umass failed");
        assert!(umass.is_finite());
    }

    #[test]
    fn test_coherence_empty() {
        let scorer = TopicCoherenceScorer::new();
        assert!(scorer.cv_coherence(&[], &[]).is_err());
        assert!(scorer.umass_coherence(&[], &[]).is_err());
    }

    #[test]
    fn test_select_n_topics() {
        let docs = sample_docs();
        let (best_k, scores) = select_n_topics(&docs, 2, 3, 30, 42).expect("select failed");
        assert!((2..=3).contains(&best_k));
        assert_eq!(scores.len(), 2);
    }

    #[test]
    fn test_select_n_topics_invalid_range() {
        let docs = sample_docs();
        assert!(select_n_topics(&docs, 5, 2, 30, 42).is_err());
    }

    #[test]
    fn test_lda_vocabulary() {
        let lda = fitted_lda(2, 10);
        let vocab = lda.vocabulary();
        assert!(vocab.contains_key("machine"));
        assert!(vocab.contains_key("cat"));
    }

    #[test]
    fn test_lda_n_topics() {
        let lda = GibbsLda::new(GibbsLdaConfig {
            n_topics: 5,
            ..Default::default()
        });
        assert_eq!(lda.n_topics(), 5);
    }

    #[test]
    fn test_coherence_window_size() {
        let scorer = TopicCoherenceScorer::new().with_window_size(5);
        assert_eq!(scorer.window_size, 5);
    }

    #[test]
    fn test_nmf_doc_topic_matrix() {
        let (matrix, vocab) = small_counts();
        let nmf = fit_nmf(&matrix, &vocab, 50).expect("fit failed");

        let dtm = nmf.doc_topic_matrix().expect("dtm failed");
        assert_eq!(dtm.dim(), (3, 2));

        // Factors must stay non-negative.
        for &v in dtm.iter() {
            assert!(v >= 0.0);
        }
    }
}