use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use std::collections::HashMap;

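/// Learning method used to fit the LDA model.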
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LdaLearningMethod {
    /// Batch variational inference: every iteration sees the full corpus.
    Batch,
    /// Online (mini-batch) variational inference.
    Online,
}

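/// Configuration for latent Dirichlet allocation.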
#[derive(Debug, Clone)]
pub struct LdaConfig {
    /// Number of topics to extract.
    pub ntopics: usize,
    /// Dirichlet prior on document-topic distributions (defaults to `1/ntopics`).
    pub doc_topic_prior: Option<f64>,
    /// Dirichlet prior on topic-word distributions (defaults to `1/ntopics`).
    pub topic_word_prior: Option<f64>,
    /// Batch or online learning.
    pub learning_method: LdaLearningMethod,
    /// Exponent controlling how quickly the online learning rate decays.
    pub learning_decay: f64,
    /// Offset that down-weights early iterations in online learning.
    pub learning_offset: f64,
    /// Maximum number of passes over the corpus.
    pub maxiter: usize,
    /// Mini-batch size for online learning.
    pub batch_size: usize,
    /// Convergence tolerance on the mean change in document-topic parameters.
    pub mean_change_tol: f64,
    /// Maximum E-step iterations per document.
    pub max_doc_update_iter: usize,
    /// Seed for reproducible runs.
    pub random_seed: Option<u64>,
}

impl Default for LdaConfig {
    fn default() -> Self {
        Self {
            ntopics: 10,
            doc_topic_prior: None,
            topic_word_prior: None,
            learning_method: LdaLearningMethod::Batch,
            learning_decay: 0.7,
            learning_offset: 10.0,
            maxiter: 10,
            batch_size: 128,
            mean_change_tol: 1e-3,
            max_doc_update_iter: 100,
            random_seed: None,
        }
    }
}

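/// A single extracted topic: its id, highest-weighted words, and an
/// optional coherence score.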
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic index.
    pub id: usize,
    /// Top words with their weights, highest first.
    pub top_words: Vec<(String, f64)>,
    /// Coherence score, if computed.
    pub coherence: Option<f64>,
}

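/// Latent Dirichlet allocation fitted by a simplified variational scheme,
/// supporting both batch and online learning.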
pub struct LatentDirichletAllocation {
    config: LdaConfig,
    /// Topic-word distributions, one row per topic.
    components: Option<Array2<f64>>,
    /// Cached exponentiated expectation of the Dirichlet component.
    exp_dirichlet_component: Option<Array2<f64>>,
    #[allow(dead_code)]
    vocabulary: Option<HashMap<usize, String>>,
    n_documents: usize,
    n_iter: usize,
    /// Per-epoch variational bound estimates (online learning only).
    #[allow(dead_code)]
    bound: Option<Vec<f64>>,
}

impl LatentDirichletAllocation {
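    /// Create a new, unfitted model from the given configuration.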
    pub fn new(config: LdaConfig) -> Self {
        Self {
            config,
            components: None,
            exp_dirichlet_component: None,
            vocabulary: None,
            n_documents: 0,
            n_iter: 0,
            bound: None,
        }
    }

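    /// Convenience constructor that only sets the number of topics.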
    pub fn with_ntopics(ntopics: usize) -> Self {
        let config = LdaConfig {
            ntopics,
            ..Default::default()
        };
        Self::new(config)
    }

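    /// Fit the model to a document-term matrix (rows are documents,
    /// columns are vocabulary terms).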
    pub fn fit(&mut self, doc_termmatrix: &Array2<f64>) -> Result<&mut Self> {
        if doc_termmatrix.nrows() == 0 || doc_termmatrix.ncols() == 0 {
            return Err(TextError::InvalidInput(
                "Document-term matrix cannot be empty".to_string(),
            ));
        }

        let n_samples = doc_termmatrix.nrows();
        let n_features = doc_termmatrix.ncols();

        // Default both Dirichlet priors to 1/ntopics when not set explicitly.
        let doc_topic_prior = self
            .config
            .doc_topic_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);
        let topic_word_prior = self
            .config
            .topic_word_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);

        let mut rng = self.create_rng();
        self.components = Some(self.initialize_components(n_features, &mut rng));

        match self.config.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
            }
            LdaLearningMethod::Online => {
                self.fit_online(doc_termmatrix, doc_topic_prior, topic_word_prior)?;
            }
        }

        self.n_documents = n_samples;
        Ok(self)
    }

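    /// Infer per-document topic distributions using the fitted topic-word
    /// distributions. Each returned row sums to 1.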
    pub fn transform(&self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let n_samples = doc_termmatrix.nrows();
        let ntopics = self.config.ntopics;

        let mut doc_topic_distr = Array2::zeros((n_samples, ntopics));
        let exp_dirichlet_component = self.get_exp_dirichlet_component()?;
        let doc_topic_prior = self.config.doc_topic_prior.unwrap_or(1.0 / ntopics as f64);

        for (doc_idx, doc) in doc_termmatrix.axis_iter(Axis(0)).enumerate() {
            let mut gamma = Array1::from_elem(ntopics, doc_topic_prior);
            self.update_doc_distribution(
                &doc.to_owned(),
                &mut gamma,
                exp_dirichlet_component,
                doc_topic_prior,
            )?;

            // Normalize so each document's topic weights form a distribution.
            let gamma_sum = gamma.sum();
            if gamma_sum > 0.0 {
                gamma /= gamma_sum;
            }

            doc_topic_distr.row_mut(doc_idx).assign(&gamma);
        }

        Ok(doc_topic_distr)
    }

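    /// Fit the model and return the document-topic distributions in one call.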
    pub fn fit_transform(&mut self, doc_termmatrix: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(doc_termmatrix)?;
        self.transform(doc_termmatrix)
    }

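    /// Return the `n_top_words` highest-weighted words per topic, resolving
    /// word indices through the supplied vocabulary.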
    pub fn get_topics(
        &self,
        n_top_words: usize,
        vocabulary: &HashMap<usize, String>,
    ) -> Result<Vec<Topic>> {
        let components = self.components.as_ref().ok_or_else(|| {
            TextError::ModelNotFitted("LDA model not fitted yet".to_string())
        })?;
        let mut topics = Vec::new();

        for (topic_idx, topic_dist) in components.axis_iter(Axis(0)).enumerate() {
            let mut word_scores: Vec<(usize, f64)> = topic_dist
                .iter()
                .enumerate()
                .map(|(idx, &score)| (idx, score))
                .collect();

            // Sort descending by weight (panics only if a weight is NaN).
            word_scores
                .sort_by(|a, b| b.1.partial_cmp(&a.1).expect("topic weights are comparable"));

            let top_words: Vec<(String, f64)> = word_scores
                .into_iter()
                .take(n_top_words)
                .filter_map(|(idx, score)| vocabulary.get(&idx).map(|word| (word.clone(), score)))
                .collect();

            topics.push(Topic {
                id: topic_idx,
                top_words,
                coherence: None,
            });
        }

        Ok(topics)
    }

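    /// Return the fitted topic-word distribution matrix, if any.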
    pub fn get_topic_word_distribution(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    fn create_rng(&self) -> StdRng {
        match self.config.random_seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => StdRng::from_rng(&mut scirs2_core::random::rng()),
        }
    }

    fn initialize_components(
        &self,
        n_features: usize,
        rng: &mut StdRng,
    ) -> Array2<f64> {
        // Random initialization, normalized so each topic row is a
        // probability distribution over the vocabulary.
        let mut components = Array2::zeros((self.config.ntopics, n_features));
        for mut row in components.axis_iter_mut(Axis(0)) {
            for val in row.iter_mut() {
                *val = rng.random_range(0.0..1.0);
            }
            let row_sum: f64 = row.sum();
            if row_sum > 0.0 {
                row /= row_sum;
            }
        }

        components
    }

    fn get_exp_dirichlet_component(&self) -> Result<&Array2<f64>> {
        self.exp_dirichlet_component
            .as_ref()
            .ok_or_else(|| TextError::ModelNotFitted("Components not initialized".to_string()))
    }

    fn fit_batch(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        let mut doc_topic_distr = Array2::from_elem((n_samples, ntopics), doc_topic_prior);

        for iter in 0..self.config.maxiter {
            self.update_exp_dirichlet_component()?;

            // E-step: update each document's topic distribution and track
            // the mean absolute change for the convergence check.
            let mut mean_change = 0.0;
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let mut gamma = doc_topic_distr.row(doc_idx).to_owned();
                let old_gamma = gamma.clone();

                self.update_doc_distribution(
                    &doc.to_owned(),
                    &mut gamma,
                    self.get_exp_dirichlet_component()?,
                    doc_topic_prior,
                )?;

                let change: f64 = (&gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
                mean_change += change / ntopics as f64;

                doc_topic_distr.row_mut(doc_idx).assign(&gamma);
            }
            mean_change /= n_samples as f64;

            // M-step: re-estimate the topic-word distributions.
            self.update_topic_distribution(doc_term_matrix, &doc_topic_distr, topic_word_prior)?;

            // Count the iteration before testing convergence so early
            // termination is still reflected in `n_iter`.
            self.n_iter = iter + 1;

            if mean_change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn fit_online(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let (n_samples, n_features) = doc_term_matrix.dim();
        // Fall back to placeholder word labels when no vocabulary is set.
        self.vocabulary
            .get_or_insert_with(|| (0..n_features).map(|i| (i, format!("word_{i}"))).collect());
        self.bound.get_or_insert_with(Vec::new);

        // Initialize components with noise above the prior if fit() has not
        // already done so.
        if self.components.is_none() {
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };

            let mut components = Array2::<f64>::zeros((self.config.ntopics, n_features));
            for i in 0..self.config.ntopics {
                for j in 0..n_features {
                    components[[i, j]] = rng.random::<f64>() + topic_word_prior;
                }
            }
            self.components = Some(components);
        }

        let batch_size = self.config.batch_size.min(n_samples);
        let n_batches = n_samples.div_ceil(batch_size);

        for epoch in 0..self.config.maxiter {
            let mut total_bound = 0.0;

            // Shuffle document order each epoch (re-seeded per epoch so
            // seeded runs stay reproducible).
            let mut doc_indices: Vec<usize> = (0..n_samples).collect();
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed + epoch as u64)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };
            doc_indices.shuffle(&mut rng);

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                let batch_docs: Vec<usize> = doc_indices[start_idx..end_idx].to_vec();

                let mut batch_gamma =
                    Array2::<f64>::zeros((batch_docs.len(), self.config.ntopics));
                let mut batch_bound = 0.0;

                for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                    let doc = doc_term_matrix.row(doc_idx);
                    let mut gamma =
                        Array1::<f64>::from_elem(self.config.ntopics, doc_topic_prior);

                    let components =
                        self.components.as_ref().expect("components initialized above");
                    let exp_topic_word_distr = components.map(|x| x.exp());
                    self.update_doc_distribution(
                        &doc.to_owned(),
                        &mut gamma,
                        &exp_topic_word_distr,
                        doc_topic_prior,
                    )?;

                    batch_gamma.row_mut(local_idx).assign(&gamma);

                    // Crude proxy for this document's variational bound
                    // contribution.
                    batch_bound += gamma.sum();
                }

                let learning_rate = self.compute_learning_rate(epoch * n_batches + batch_idx);
                self.update_topic_word_distribution(
                    &batch_docs,
                    doc_term_matrix,
                    &batch_gamma,
                    topic_word_prior,
                    learning_rate,
                    n_samples,
                )?;

                total_bound += batch_bound;
            }

            if let Some(ref mut bound) = self.bound {
                bound.push(total_bound / n_samples as f64);
            }

            // Count the epoch before testing convergence so early
            // termination is still reflected in `n_iter`.
            self.n_iter = epoch + 1;

            // Stop early once the bound proxy stops changing.
            if let Some(ref bound) = self.bound {
                if bound.len() > 1 {
                    let current_bound = bound[bound.len() - 1];
                    let prev_bound = bound[bound.len() - 2];
                    let change = (current_bound - prev_bound).abs();
                    if change < self.config.mean_change_tol {
                        break;
                    }
                }
            }
        }

        self.n_documents = n_samples;
        Ok(())
    }

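    /// Learning-rate schedule for the online updates:
    /// `rho_t = (learning_offset + t)^(-learning_decay)`.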
    fn compute_learning_rate(&self, iteration: usize) -> f64 {
        (self.config.learning_offset + iteration as f64).powf(-self.config.learning_decay)
    }

    fn update_topic_word_distribution(
        &mut self,
        batch_docs: &[usize],
        doc_term_matrix: &Array2<f64>,
        batch_gamma: &Array2<f64>,
        topic_word_prior: f64,
        learning_rate: f64,
        total_docs: usize,
    ) -> Result<()> {
        let batch_size = batch_docs.len();
        let n_features = doc_term_matrix.ncols();

        if let Some(ref mut components) = self.components {
            // Accumulate expected word-topic counts for this mini-batch.
            let mut batch_stats = Array2::<f64>::zeros((self.config.ntopics, n_features));

            for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                let doc = doc_term_matrix.row(doc_idx);
                let gamma = batch_gamma.row(local_idx);
                let gamma_sum = gamma.sum();

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            let phi = gamma[topic_idx] / gamma_sum;
                            batch_stats[[topic_idx, word_idx]] += count * phi;
                        }
                    }
                }
            }

            // Scale the mini-batch statistics up to the full corpus size.
            let scale_factor = total_docs as f64 / batch_size as f64;
            batch_stats.mapv_inplace(|x| x * scale_factor);

            // Stochastic update: blend old parameters with the new estimate
            // according to the current learning rate.
            for topic_idx in 0..self.config.ntopics {
                for word_idx in 0..n_features {
                    let old_val = components[[topic_idx, word_idx]];
                    let new_val = topic_word_prior + batch_stats[[topic_idx, word_idx]];
                    components[[topic_idx, word_idx]] =
                        (1.0 - learning_rate) * old_val + learning_rate * new_val;
                }
            }
        }

        Ok(())
    }

    fn update_doc_distribution(
        &self,
        doc: &Array1<f64>,
        gamma: &mut Array1<f64>,
        exp_topic_word_distr: &Array2<f64>,
        doc_topic_prior: f64,
    ) -> Result<()> {
        let ntopics = self.config.ntopics;

        for _ in 0..self.config.max_doc_update_iter {
            let old_gamma = gamma.clone();

            gamma.fill(doc_topic_prior);

            // Mean-field update (standard LDA E-step): distribute each word's
            // count over topics in proportion to
            // old_gamma[k] * exp_topic_word_distr[k, word], then accumulate
            // the expected counts into gamma on top of the prior.
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    let mut phi = vec![0.0; ntopics];
                    let mut phi_sum = 0.0;
                    for k in 0..ntopics {
                        phi[k] = old_gamma[k] * exp_topic_word_distr[[k, word_idx]];
                        phi_sum += phi[k];
                    }
                    if phi_sum > 0.0 {
                        for k in 0..ntopics {
                            gamma[k] += count * phi[k] / phi_sum;
                        }
                    }
                }
            }

            let change: f64 = (&*gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
            if change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn update_topic_distribution(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_distr: &Array2<f64>,
        topic_word_prior: f64,
    ) -> Result<()> {
        if let Some(ref mut components) = self.components {
            // M-step: reset to the prior, accumulate expected word-topic
            // counts, then renormalize each topic row.
            components.fill(topic_word_prior);

            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let doc_topics = doc_topic_distr.row(doc_idx);

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            components[[topic_idx, word_idx]] += count * doc_topics[topic_idx];
                        }
                    }
                }
            }

            for mut topic in components.axis_iter_mut(Axis(0)) {
                let topic_sum = topic.sum();
                if topic_sum > 0.0 {
                    topic /= topic_sum;
                }
            }
        }

        Ok(())
    }

    fn update_exp_dirichlet_component(&mut self) -> Result<()> {
        // Simplified variant: a full implementation would cache
        // exp(E[log beta]) via the digamma function; here the normalized
        // topic-word matrix itself is used as the stand-in.
        if let Some(ref components) = self.components {
            self.exp_dirichlet_component = Some(components.clone());
        }
        Ok(())
    }
}

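/// Builder-style configuration for [`LatentDirichletAllocation`].
///
/// A minimal usage sketch (not compiled as a doctest; assumes a
/// document-term matrix is already in scope):
///
/// ```ignore
/// let mut lda = LdaBuilder::new()
///     .ntopics(10)
///     .maxiter(20)
///     .random_seed(42)
///     .build();
/// let doc_topics = lda.fit_transform(&doc_term_matrix)?;
/// ```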
pub struct LdaBuilder {
    config: LdaConfig,
}

impl LdaBuilder {
    pub fn new() -> Self {
        Self {
            config: LdaConfig::default(),
        }
    }

    pub fn ntopics(mut self, ntopics: usize) -> Self {
        self.config.ntopics = ntopics;
        self
    }

    pub fn doc_topic_prior(mut self, prior: f64) -> Self {
        self.config.doc_topic_prior = Some(prior);
        self
    }

    pub fn topic_word_prior(mut self, prior: f64) -> Self {
        self.config.topic_word_prior = Some(prior);
        self
    }

    pub fn learning_method(mut self, method: LdaLearningMethod) -> Self {
        self.config.learning_method = method;
        self
    }

    pub fn maxiter(mut self, maxiter: usize) -> Self {
        self.config.maxiter = maxiter;
        self
    }

    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

    pub fn build(self) -> LatentDirichletAllocation {
        LatentDirichletAllocation::new(self.config)
    }
}

impl Default for LdaBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lda_creation() {
        let lda = LatentDirichletAllocation::with_ntopics(5);
        assert_eq!(lda.config.ntopics, 5);
    }

    #[test]
    fn test_lda_builder() {
        let lda = LdaBuilder::new()
            .ntopics(10)
            .doc_topic_prior(0.1)
            .maxiter(20)
            .random_seed(42)
            .build();

        assert_eq!(lda.config.ntopics, 10);
        assert_eq!(lda.config.doc_topic_prior, Some(0.1));
        assert_eq!(lda.config.maxiter, 20);
        assert_eq!(lda.config.random_seed, Some(42));
    }

    #[test]
    fn test_lda_fit_transform() {
        // Four documents over a six-word vocabulary, one row per document.
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 6),
            vec![
                1.0, 1.0, 0.0, 0.0, 0.0, 0.0, // doc 0
                0.0, 1.0, 1.0, 0.0, 0.0, 0.0, // doc 1
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, // doc 2
                0.0, 0.0, 0.0, 0.0, 1.0, 1.0, // doc 3
            ],
        )
        .expect("matrix shape matches data");

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        let doc_topics = lda
            .fit_transform(&doc_term_matrix)
            .expect("fit_transform succeeds");

        assert_eq!(doc_topics.nrows(), 4);
        assert_eq!(doc_topics.ncols(), 2);

        // Each document's topic distribution should sum to 1.
        for row in doc_topics.axis_iter(Axis(0)) {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_get_topics() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 3),
            vec![2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 2.0, 2.0, 1.0, 1.0],
        )
        .expect("matrix shape matches data");

        let mut vocabulary = HashMap::new();
        vocabulary.insert(0, "word1".to_string());
        vocabulary.insert(1, "word2".to_string());
        vocabulary.insert(2, "word3".to_string());

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        lda.fit(&doc_term_matrix).expect("fit succeeds");

        let topics = lda.get_topics(3, &vocabulary).expect("model is fitted");
        assert_eq!(topics.len(), 2);

        for topic in &topics {
            assert_eq!(topic.top_words.len(), 3);
        }
    }
}