use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use std::collections::HashMap;

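/// Learning schedule for LDA fitting: batch variational EM over the full
/// corpus, or online (mini-batch) updates with a decaying learning rate.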
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LdaLearningMethod {
    /// Update topic-word statistics from the full corpus each iteration.
    Batch,
    /// Update topic-word statistics from shuffled mini-batches.
    Online,
}

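/// Configuration for [`LatentDirichletAllocation`].
///
/// Priors left as `None` fall back to the common `1 / ntopics` heuristic;
/// `learning_decay`, `learning_offset`, and `batch_size` only affect the
/// online method.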
#[derive(Debug, Clone)]
pub struct LdaConfig {
    /// Number of topics to extract.
    pub ntopics: usize,
    /// Dirichlet prior on document-topic distributions; defaults to `1 / ntopics`.
    pub doc_topic_prior: Option<f64>,
    /// Dirichlet prior on topic-word distributions; defaults to `1 / ntopics`.
    pub topic_word_prior: Option<f64>,
    /// Batch or online (mini-batch) fitting.
    pub learning_method: LdaLearningMethod,
    /// Exponent controlling how quickly the online learning rate decays.
    pub learning_decay: f64,
    /// Offset that down-weights early online iterations.
    pub learning_offset: f64,
    /// Maximum number of passes over the corpus.
    pub maxiter: usize,
    /// Mini-batch size for online learning.
    pub batch_size: usize,
    /// Mean absolute change in gamma below which updates are converged.
    pub mean_change_tol: f64,
    /// Maximum iterations for the per-document E-step.
    pub max_doc_update_iter: usize,
    /// Seed for reproducible initialization and shuffling.
    pub random_seed: Option<u64>,
}

impl Default for LdaConfig {
    fn default() -> Self {
        Self {
            ntopics: 10,
            doc_topic_prior: None,
            topic_word_prior: None,
            learning_method: LdaLearningMethod::Batch,
            learning_decay: 0.7,
            learning_offset: 10.0,
            maxiter: 10,
            batch_size: 128,
            mean_change_tol: 1e-3,
            max_doc_update_iter: 100,
            random_seed: None,
        }
    }
}

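/// A single extracted topic: its index and highest-weight words.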
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic index.
    pub id: usize,
    /// Highest-weight words and their weights, in descending order.
    pub top_words: Vec<(String, f64)>,
    /// Optional coherence score (not computed by `get_topics`).
    pub coherence: Option<f64>,
}

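/// Latent Dirichlet Allocation topic model fitted with a simplified
/// variational EM procedure.
///
/// A minimal usage sketch (the `Array2` layout is documents x terms; the
/// import path is illustrative):
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// // Four documents over six terms.
/// let counts = Array2::from_shape_vec((4, 6), vec![/* term counts */]).unwrap();
/// let mut lda = LatentDirichletAllocation::with_ntopics(2);
/// let doc_topics = lda.fit_transform(&counts).unwrap();
/// assert_eq!(doc_topics.dim(), (4, 2));
/// ```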
pub struct LatentDirichletAllocation {
    config: LdaConfig,
    /// Topic-word distribution (ntopics x n_features), row-normalized after fitting.
    components: Option<Array2<f64>>,
    /// Cached transform of `components` used during the E-step.
    exp_dirichlet_component: Option<Array2<f64>>,
    #[allow(dead_code)]
    vocabulary: Option<HashMap<usize, String>>,
    n_documents: usize,
    n_iter: usize,
    #[allow(dead_code)]
    bound: Option<Vec<f64>>,
}

impl LatentDirichletAllocation {
    /// Creates an unfitted model from the given configuration.
    pub fn new(config: LdaConfig) -> Self {
        Self {
            config,
            components: None,
            exp_dirichlet_component: None,
            vocabulary: None,
            n_documents: 0,
            n_iter: 0,
            bound: None,
        }
    }

    /// Convenience constructor that only overrides the number of topics.
    pub fn with_ntopics(ntopics: usize) -> Self {
        let config = LdaConfig {
            ntopics,
            ..Default::default()
        };
        Self::new(config)
    }

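    /// Fits the model to a documents x terms count matrix.
    ///
    /// Returns an error if the matrix is empty; otherwise dispatches to the
    /// batch or online routine selected in the config.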
    pub fn fit(&mut self, doc_termmatrix: &Array2<f64>) -> Result<&mut Self> {
        if doc_termmatrix.nrows() == 0 || doc_termmatrix.ncols() == 0 {
            return Err(TextError::InvalidInput(
                "Document-term matrix cannot be empty".to_string(),
            ));
        }

        let n_samples = doc_termmatrix.nrows();
        let n_features = doc_termmatrix.ncols();

        // Both priors default to the common 1/k heuristic.
        let doc_topic_prior = self
            .config
            .doc_topic_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);
        let topic_word_prior = self
            .config
            .topic_word_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);

        let mut rng = self.create_rng();
        self.components = Some(self.initialize_components(n_features, &mut rng));

        match self.config.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
            }
            LdaLearningMethod::Online => {
                self.fit_online(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
            }
        }

        self.n_documents = n_samples;
        Ok(self)
    }

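    /// Infers the document-topic distribution for the given documents.
    ///
    /// Each returned row is normalized to sum to 1. Requires `fit` to have
    /// been called first.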
    pub fn transform(&self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let n_samples = doc_termmatrix.nrows();
        let ntopics = self.config.ntopics;

        let mut doc_topic_distr = Array2::zeros((n_samples, ntopics));

        let exp_dirichlet_component = self.get_exp_dirichlet_component()?;

        let doc_topic_prior = self.config.doc_topic_prior.unwrap_or(1.0 / ntopics as f64);

        for (doc_idx, doc) in doc_termmatrix.axis_iter(Axis(0)).enumerate() {
            let mut gamma = Array1::from_elem(ntopics, doc_topic_prior);
            self.update_doc_distribution(
                &doc.to_owned(),
                &mut gamma,
                exp_dirichlet_component,
                doc_topic_prior,
            )?;

            // Normalize gamma so each row is a proper topic distribution.
            let gamma_sum = gamma.sum();
            if gamma_sum > 0.0 {
                gamma /= gamma_sum;
            }

            doc_topic_distr.row_mut(doc_idx).assign(&gamma);
        }

        Ok(doc_topic_distr)
    }

    /// Fits the model and returns the document-topic distribution in one call.
    pub fn fit_transform(&mut self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(doc_termmatrix)?;
        self.transform(doc_termmatrix)
    }

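    /// Returns the `n_top_words` highest-weight words per topic, resolved
    /// through the provided index-to-word vocabulary.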
    pub fn get_topics(
        &self,
        n_top_words: usize,
        vocabulary: &HashMap<usize, String>,
    ) -> Result<Vec<Topic>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let components = self.components.as_ref().unwrap();
        let mut topics = Vec::new();

        for (topic_idx, topic_dist) in components.axis_iter(Axis(0)).enumerate() {
            let mut word_scores: Vec<(usize, f64)> = topic_dist
                .iter()
                .enumerate()
                .map(|(idx, &score)| (idx, score))
                .collect();

            // Sort words by weight, descending.
            word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

            let top_words: Vec<(String, f64)> = word_scores
                .into_iter()
                .take(n_top_words)
                .filter_map(|(idx, score)| vocabulary.get(&idx).map(|word| (word.clone(), score)))
                .collect();

            topics.push(Topic {
                id: topic_idx,
                top_words,
                coherence: None,
            });
        }

        Ok(topics)
    }

    /// Returns the fitted topic-word distribution, if any.
    pub fn get_topic_word_distribution(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    fn create_rng(&self) -> StdRng {
        match self.config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::rng()),
        }
    }

    fn initialize_components(
        &self,
        n_features: usize,
        rng: &mut StdRng,
    ) -> Array2<f64> {
        let mut components = Array2::zeros((self.config.ntopics, n_features));
        for mut row in components.axis_iter_mut(Axis(0)) {
            // Draw uniform weights, then normalize each topic row to sum to 1.
            for val in row.iter_mut() {
                *val = rng.random_range(0.0..1.0);
            }
            let row_sum: f64 = row.sum();
            if row_sum > 0.0 {
                row /= row_sum;
            }
        }

        components
    }

    fn get_exp_dirichlet_component(&self) -> Result<&Array2<f64>> {
        if self.exp_dirichlet_component.is_none() {
            return Err(TextError::ModelNotFitted(
                "Components not initialized".to_string(),
            ));
        }
        Ok(self.exp_dirichlet_component.as_ref().unwrap())
    }

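    /// Batch variational EM: each iteration runs an E-step over every
    /// document (updating gamma) and an M-step over the full corpus
    /// (re-estimating the topic-word distribution), stopping once the mean
    /// change in gamma falls below `mean_change_tol`.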
    fn fit_batch(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        let mut doc_topic_distr = Array2::from_elem((n_samples, ntopics), doc_topic_prior);

        for iter in 0..self.config.maxiter {
            self.update_exp_dirichlet_component()?;

            // E-step: update gamma for every document and track the mean change.
            let mut mean_change = 0.0;
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let mut gamma = doc_topic_distr.row(doc_idx).to_owned();
                let old_gamma = gamma.clone();

                self.update_doc_distribution(
                    &doc.to_owned(),
                    &mut gamma,
                    self.get_exp_dirichlet_component()?,
                    doc_topic_prior,
                )?;

                let change: f64 = (&gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
                mean_change += change / ntopics as f64;

                doc_topic_distr.row_mut(doc_idx).assign(&gamma);
            }
            mean_change /= n_samples as f64;

            // M-step: re-estimate the topic-word distribution.
            self.update_topic_distribution(doc_term_matrix, &doc_topic_distr, topic_word_prior)?;

            // Record the completed iteration before testing for convergence,
            // so n_iter stays accurate when we stop early.
            self.n_iter = iter + 1;

            if mean_change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

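    /// Online variational EM: documents are shuffled each epoch and processed
    /// in mini-batches; each batch's statistics are blended into the global
    /// components using a learning rate that decays across updates.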
    fn fit_online(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let (n_samples, n_features) = doc_term_matrix.dim();
        self.vocabulary
            .get_or_insert_with(|| (0..n_features).map(|i| (i, format!("word_{i}"))).collect());
        self.bound.get_or_insert_with(Vec::new);

        // Initialize components with small random positive values if needed.
        if self.components.is_none() {
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };

            let mut components = Array2::<f64>::zeros((self.config.ntopics, n_features));
            for i in 0..self.config.ntopics {
                for j in 0..n_features {
                    components[[i, j]] = rng.random::<f64>() + topic_word_prior;
                }
            }
            self.components = Some(components);
        }

        let batch_size = self.config.batch_size.min(n_samples);
        let n_batches = n_samples.div_ceil(batch_size);

        for epoch in 0..self.config.maxiter {
            let mut total_bound = 0.0;

            // Shuffle the documents each epoch; offsetting the seed by the
            // epoch keeps seeded runs reproducible without repeating the order.
            let mut doc_indices: Vec<usize> = (0..n_samples).collect();
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed + epoch as u64)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };
            doc_indices.shuffle(&mut rng);

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                let batch_docs: Vec<usize> = doc_indices[start_idx..end_idx].to_vec();

                let mut batch_gamma = Array2::<f64>::zeros((batch_docs.len(), self.config.ntopics));
                let mut batch_bound = 0.0;

                // E-step for the mini-batch.
                for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                    let doc = doc_term_matrix.row(doc_idx);
                    let mut gamma = Array1::<f64>::from_elem(self.config.ntopics, doc_topic_prior);

                    let components = self.components.as_ref().unwrap();
                    let exp_topic_word_distr = components.map(|x| x.exp());
                    self.update_doc_distribution(
                        &doc.to_owned(),
                        &mut gamma,
                        &exp_topic_word_distr,
                        doc_topic_prior,
                    )?;

                    batch_gamma.row_mut(local_idx).assign(&gamma);

                    // Crude convergence statistic (not the true variational bound).
                    batch_bound += gamma.sum();
                }

                // M-step: blend the batch statistics into the global components.
                let learning_rate = self.compute_learning_rate(epoch * n_batches + batch_idx);
                self.update_topic_word_distribution(
                    &batch_docs,
                    doc_term_matrix,
                    &batch_gamma,
                    topic_word_prior,
                    learning_rate,
                    n_samples,
                )?;

                total_bound += batch_bound;
            }

            if let Some(ref mut bound) = self.bound {
                bound.push(total_bound / n_samples as f64);
            }

            // Record the completed epoch before testing for convergence,
            // mirroring fit_batch.
            self.n_iter = epoch + 1;

            // Stop early if the per-epoch statistic has stabilized.
            if let Some(ref bound) = self.bound {
                if bound.len() > 1 {
                    let current_bound = bound[bound.len() - 1];
                    let prev_bound = bound[bound.len() - 2];
                    let change = (current_bound - prev_bound).abs();
                    if change < self.config.mean_change_tol {
                        break;
                    }
                }
            }
        }

        self.n_documents = n_samples;
        Ok(())
    }

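    /// Learning rate for online update `t`:
    /// `rho_t = (learning_offset + t)^(-learning_decay)`.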
    fn compute_learning_rate(&self, iteration: usize) -> f64 {
        (self.config.learning_offset + iteration as f64).powf(-self.config.learning_decay)
    }

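    /// Online M-step: accumulates expected word-topic counts from the batch,
    /// scales them to corpus size (`total_docs / batch_size`), and blends them
    /// into the components as
    /// `lambda_new = (1 - rho) * lambda_old + rho * (eta + scaled_stats)`.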
    fn update_topic_word_distribution(
        &mut self,
        batch_docs: &[usize],
        doc_term_matrix: &Array2<f64>,
        batch_gamma: &Array2<f64>,
        topic_word_prior: f64,
        learning_rate: f64,
        total_docs: usize,
    ) -> Result<()> {
        let batch_size = batch_docs.len();
        let n_features = doc_term_matrix.ncols();

        if let Some(ref mut components) = self.components {
            let mut batch_stats = Array2::<f64>::zeros((self.config.ntopics, n_features));

            // Accumulate expected word-topic counts from the batch.
            for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                let doc = doc_term_matrix.row(doc_idx);
                let gamma = batch_gamma.row(local_idx);
                let gamma_sum = gamma.sum();

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            let phi = gamma[topic_idx] / gamma_sum;
                            batch_stats[[topic_idx, word_idx]] += count * phi;
                        }
                    }
                }
            }

            // Scale batch statistics up to the full corpus size.
            let scale_factor = total_docs as f64 / batch_size as f64;
            batch_stats.mapv_inplace(|x| x * scale_factor);

            // Blend old and new estimates with the decaying learning rate.
            for topic_idx in 0..self.config.ntopics {
                for word_idx in 0..n_features {
                    let old_val = components[[topic_idx, word_idx]];
                    let new_val = topic_word_prior + batch_stats[[topic_idx, word_idx]];
                    components[[topic_idx, word_idx]] =
                        (1.0 - learning_rate) * old_val + learning_rate * new_val;
                }
            }
        }

        Ok(())
    }

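    /// Per-document E-step: iteratively re-estimates gamma from the word
    /// counts and the current topic-word matrix until the change drops below
    /// `mean_change_tol` or `max_doc_update_iter` is reached.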
    fn update_doc_distribution(
        &self,
        doc: &Array1<f64>,
        gamma: &mut Array1<f64>,
        exp_topic_word_distr: &Array2<f64>,
        doc_topic_prior: f64,
    ) -> Result<()> {
        let ntopics = gamma.len();

        for _ in 0..self.config.max_doc_update_iter {
            let old_gamma = gamma.clone();

            gamma.fill(doc_topic_prior);

            // Fold each word's counts back into gamma, weighting topics by
            // the responsibility phi[k] ∝ old_gamma[k] * exp_topic_word_distr[k, w].
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    let mut phi = Array1::<f64>::zeros(ntopics);
                    let mut phi_sum = 0.0;
                    for k in 0..ntopics {
                        phi[k] = old_gamma[k] * exp_topic_word_distr[[k, word_idx]];
                        phi_sum += phi[k];
                    }
                    if phi_sum > 0.0 {
                        for k in 0..ntopics {
                            gamma[k] += count * phi[k] / phi_sum;
                        }
                    }
                }
            }

            let change: f64 = (&*gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
            if change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn update_topic_distribution(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_distr: &Array2<f64>,
        topic_word_prior: f64,
    ) -> Result<()> {
        if let Some(ref mut components) = self.components {
            // Reset to the prior, then accumulate expected word-topic counts.
            components.fill(topic_word_prior);

            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let doc_topics = doc_topic_distr.row(doc_idx);

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            components[[topic_idx, word_idx]] += count * doc_topics[topic_idx];
                        }
                    }
                }
            }

            // Normalize each topic row into a probability distribution.
            for mut topic in components.axis_iter_mut(Axis(0)) {
                let topic_sum = topic.sum();
                if topic_sum > 0.0 {
                    topic /= topic_sum;
                }
            }
        }

        Ok(())
    }

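    /// Refreshes the cached matrix used by the E-step.
    ///
    /// In full variational LDA this would be `exp(E[log beta])`, i.e.
    /// `exp(digamma(lambda_kw) - digamma(sum_w lambda_kw))`; this
    /// implementation uses the row-normalized components directly as a
    /// simplification.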
    fn update_exp_dirichlet_component(&mut self) -> Result<()> {
        if let Some(ref components) = self.components {
            self.exp_dirichlet_component = Some(components.clone());
        }
        Ok(())
    }
}

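/// Fluent builder for [`LatentDirichletAllocation`].
///
/// A minimal sketch mirroring the tests below:
///
/// ```ignore
/// let lda = LdaBuilder::new()
///     .ntopics(10)
///     .doc_topic_prior(0.1)
///     .random_seed(42)
///     .build();
/// ```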
pub struct LdaBuilder {
    config: LdaConfig,
}

impl LdaBuilder {
    pub fn new() -> Self {
        Self {
            config: LdaConfig::default(),
        }
    }

    pub fn ntopics(mut self, ntopics: usize) -> Self {
        self.config.ntopics = ntopics;
        self
    }

    pub fn doc_topic_prior(mut self, prior: f64) -> Self {
        self.config.doc_topic_prior = Some(prior);
        self
    }

    pub fn topic_word_prior(mut self, prior: f64) -> Self {
        self.config.topic_word_prior = Some(prior);
        self
    }

    pub fn learning_method(mut self, method: LdaLearningMethod) -> Self {
        self.config.learning_method = method;
        self
    }

    pub fn maxiter(mut self, maxiter: usize) -> Self {
        self.config.maxiter = maxiter;
        self
    }

    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

    pub fn build(self) -> LatentDirichletAllocation {
        LatentDirichletAllocation::new(self.config)
    }
}

impl Default for LdaBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lda_creation() {
        let lda = LatentDirichletAllocation::with_ntopics(5);
        assert_eq!(lda.config.ntopics, 5);
    }

    #[test]
    fn test_lda_builder() {
        let lda = LdaBuilder::new()
            .ntopics(10)
            .doc_topic_prior(0.1)
            .maxiter(20)
            .random_seed(42)
            .build();

        assert_eq!(lda.config.ntopics, 10);
        assert_eq!(lda.config.doc_topic_prior, Some(0.1));
        assert_eq!(lda.config.maxiter, 20);
        assert_eq!(lda.config.random_seed, Some(42));
    }

    #[test]
    fn test_lda_fit_transform() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 6),
            vec![
                1.0, 1.0, 0.0, 0.0, 0.0, 0.0,
                0.0, 1.0, 1.0, 0.0, 0.0, 0.0,
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0,
                0.0, 0.0, 0.0, 0.0, 1.0, 1.0,
            ],
        )
        .unwrap();

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        let doc_topics = lda.fit_transform(&doc_term_matrix).unwrap();

        assert_eq!(doc_topics.nrows(), 4);
        assert_eq!(doc_topics.ncols(), 2);

        // Each document's topic distribution should sum to 1.
        for row in doc_topics.axis_iter(Axis(0)) {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_get_topics() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 3),
            vec![2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 2.0, 2.0, 1.0, 1.0],
        )
        .unwrap();

        let mut vocabulary = HashMap::new();
        vocabulary.insert(0, "word1".to_string());
        vocabulary.insert(1, "word2".to_string());
        vocabulary.insert(2, "word3".to_string());

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        lda.fit(&doc_term_matrix).unwrap();

        let topics = lda.get_topics(3, &vocabulary).unwrap();
        assert_eq!(topics.len(), 2);

        for topic in &topics {
            assert_eq!(topic.top_words.len(), 3);
        }
    }
}