scirs2_text/topic_modeling.rs

//! # Topic Modeling Module
//!
//! This module provides advanced topic modeling algorithms for discovering
//! hidden thematic structures in document collections, with a focus on
//! Latent Dirichlet Allocation (LDA).
//!
//! ## Overview
//!
//! Topic modeling is an unsupervised machine learning technique that discovers
//! abstract "topics" that occur in a collection of documents. This module implements:
//!
//! - **Latent Dirichlet Allocation (LDA)**: The most popular topic modeling algorithm
//! - **Batch and Online Learning**: Different training strategies for various dataset sizes
//! - **Coherence Metrics**: Model evaluation using CV, UMass, and UCI coherence
//! - **Topic Visualization**: Tools for understanding and presenting results
//!
//! ## Quick Start
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LatentDirichletAllocation, LdaConfig, LdaLearningMethod};
//! use scirs2_text::vectorize::{CountVectorizer, Vectorizer};
//! use std::collections::HashMap;
//!
//! // Sample documents
//! let documents = vec![
//!     "machine learning algorithms are powerful tools",
//!     "natural language processing uses machine learning",
//!     "deep learning is a subset of machine learning",
//!     "cats and dogs are popular pets",
//!     "pet care requires attention and love",
//!     "dogs need regular exercise and training"
//! ];
//!
//! // Vectorize documents
//! let mut vectorizer = CountVectorizer::new(false);
//! let doc_term_matrix = vectorizer.fit_transform(&documents).unwrap();
//!
//! // Configure LDA
//! let config = LdaConfig {
//!     ntopics: 2,
//!     doc_topic_prior: Some(0.1),    // Alpha parameter
//!     topic_word_prior: Some(0.01),  // Beta parameter
//!     learning_method: LdaLearningMethod::Batch,
//!     maxiter: 100,
//!     mean_change_tol: 1e-4,
//!     random_seed: Some(42),
//!     ..Default::default()
//! };
//!
//! // Train the model
//! let mut lda = LatentDirichletAllocation::new(config);
//! lda.fit(&doc_term_matrix).unwrap();
//!
//! // Create vocabulary mapping for topic display
//! let vocab_map: HashMap<usize, String> = (0..1000).map(|i| (i, format!("word_{}", i))).collect();
//! // Get topics
//! let topics = lda.get_topics(10, &vocab_map).unwrap(); // Top 10 words per topic
//! for (i, topic) in topics.iter().enumerate() {
//!     println!("Topic {}: {:?}", i, topic);
//! }
//!
//! // Transform documents to topic space
//! let doc_topics = lda.transform(&doc_term_matrix).unwrap();
//! println!("Document-topic distribution: {:?}", doc_topics);
//! ```
//!
//! ## Advanced Usage
//!
//! ### Online Learning for Large Datasets
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LdaConfig, LdaLearningMethod, LatentDirichletAllocation};
//!
//! let config = LdaConfig {
//!     ntopics: 10,
//!     learning_method: LdaLearningMethod::Online,
//!     batch_size: 64,                // Mini-batch size
//!     learning_decay: 0.7,           // Learning rate decay
//!     learning_offset: 10.0,         // Learning rate offset
//!     maxiter: 500,
//!     ..Default::default()
//! };
//!
//! let mut lda = LatentDirichletAllocation::new(config);
//! // Process documents in batches for memory efficiency
//! ```
//!
//! ### Custom Hyperparameters
//!
//! ```rust
//! use scirs2_text::topic_modeling::LdaConfig;
//!
//! let config = LdaConfig {
//!     ntopics: 20,
//!     doc_topic_prior: Some(50.0 / 20.0), // Symmetric Dirichlet
//!     topic_word_prior: Some(0.1),        // Sparse topics
//!     maxiter: 1000,                      // More iterations
//!     mean_change_tol: 1e-6,              // Stricter convergence
//!     ..Default::default()
//! };
//! ```
//!
//! ### Model Evaluation
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LatentDirichletAllocation, LdaConfig};
//! use scirs2_text::vectorize::{CountVectorizer, Vectorizer};
//! use std::collections::HashMap;
//!
//! # let documents = vec!["the quick brown fox", "jumped over the lazy dog"];
//! # let mut vectorizer = CountVectorizer::new(false);
//! # let doc_term_matrix = vectorizer.fit_transform(&documents).unwrap();
//! # let mut lda = LatentDirichletAllocation::new(LdaConfig::default());
//! # lda.fit(&doc_term_matrix).unwrap();
//! # let vocab_map: HashMap<usize, String> = (0..100).map(|i| (i, format!("word_{}", i))).collect();
//! // Get model information
//! let topics = lda.get_topics(5, &vocab_map); // Top 5 words per topic
//! println!("Number of topics: {}", topics.unwrap().len());
//!
//! // Get document-topic probabilities
//! let doc_topic_probs = lda.transform(&doc_term_matrix).unwrap();
//! println!("Document-topic shape: {:?}", doc_topic_probs.shape());
//! ```
//!
//! ## Parameter Tuning Guide
//!
//! ### Number of Topics
//! - **Too few**: Broad, less meaningful topics
//! - **Too many**: Narrow, potentially noisy topics
//! - **Recommendation**: Start with √(number of documents) and tune based on coherence (see the sketch after this guide)
//!
//! ### Alpha (doc_topic_prior)
//! - **High values (e.g., 1.0)**: Documents contain many topics
//! - **Low values (e.g., 0.1)**: Documents contain few topics
//! - **Default**: 1/ntopics (symmetric) when left unset
//!
//! ### Beta (topic_word_prior)
//! - **High values (e.g., 1.0)**: Topics contain many words
//! - **Low values (e.g., 0.01)**: Topics are more focused
//! - **Default**: 1/ntopics when left unset; small values such as 0.01 give sparse topics
//!
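//! The sketch below ties these guidelines together for a hypothetical corpus size
//! (the square-root starting point and the prior values are heuristics to tune, not fixed rules):
//!
//! ```rust
//! use scirs2_text::topic_modeling::LdaConfig;
//!
//! // Suppose the corpus contains 400 documents.
//! let n_docs = 400usize;
//!
//! // Heuristic starting point: roughly sqrt(number of documents) topics.
//! let ntopics = (n_docs as f64).sqrt().round() as usize;
//!
//! let config = LdaConfig {
//!     ntopics,
//!     doc_topic_prior: Some(1.0 / ntopics as f64), // few topics per document
//!     topic_word_prior: Some(0.01),                // focused, sparse topics
//!     ..Default::default()
//! };
//! assert_eq!(config.ntopics, 20);
//! ```
//!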
//! ## Performance Optimization
//!
//! 1. **Use Online Learning**: For datasets that don't fit in memory
//! 2. **Tune Batch Size**: Balance between speed and convergence stability
//! 3. **Set Tolerance**: Stop early when convergence is reached (see the sketch below)
//! 4. **Monitor Perplexity**: Track model performance during training
//! 5. **Parallel Processing**: Enable for faster vocabulary building
//!
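//! As a rough illustration of points 1-3 (a sketch only; the batch size and tolerance values
//! are placeholders that should be tuned for the corpus at hand):
//!
//! ```rust
//! use scirs2_text::topic_modeling::{LatentDirichletAllocation, LdaConfig, LdaLearningMethod};
//!
//! let config = LdaConfig {
//!     ntopics: 50,
//!     learning_method: LdaLearningMethod::Online, // stream mini-batches instead of the full corpus
//!     batch_size: 256,                            // larger batches: steadier updates, more memory
//!     mean_change_tol: 1e-3,                      // stop early once updates become small
//!     ..Default::default()
//! };
//!
//! let mut lda = LatentDirichletAllocation::new(config);
//! // Call lda.fit(...) or lda.fit_transform(...) on the document-term matrix as usual
//! ```
//!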
//! ## Mathematical Background
//!
//! LDA assumes each document is a mixture of topics, and each topic is a distribution over words.
//! The generative process:
//!
//! 1. For each topic k: Draw word distribution φₖ ~ Dirichlet(β)
//! 2. For each document d:
//!    - Draw topic distribution θ_d ~ Dirichlet(α)
//!    - For each word n in document d:
//!      - Draw topic assignment z_{d,n} ~ Multinomial(θ_d)
//!      - Draw word w_{d,n} ~ Multinomial(φ_{z_{d,n}})
//!
//! The goal is to infer the posterior distributions of θ and φ given the observed words.
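//!
//! Written out, the corresponding joint distribution is
//!
//! p(θ, z, w | α, β) = ∏ₖ p(φₖ | β) · ∏_d p(θ_d | α) · ∏_n p(z_{d,n} | θ_d) · p(w_{d,n} | φ_{z_{d,n}})
//!
//! Exact posterior inference under this model is intractable, so practical implementations
//! (this module included) rely on approximate schemes such as variational mean-field updates.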

use crate::error::{Result, TextError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use scirs2_core::random::seq::SliceRandom;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use std::collections::HashMap;

/// Learning method for LDA
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LdaLearningMethod {
    /// Batch learning - process all documents at once
    Batch,
    /// Online learning - process documents in mini-batches
    Online,
}

/// Latent Dirichlet Allocation configuration
#[derive(Debug, Clone)]
pub struct LdaConfig {
    /// Number of topics
    pub ntopics: usize,
    /// Prior for document-topic distribution (alpha)
    pub doc_topic_prior: Option<f64>,
    /// Prior for topic-word distribution (eta)
    pub topic_word_prior: Option<f64>,
    /// Learning method
    pub learning_method: LdaLearningMethod,
    /// Learning decay for online learning
    pub learning_decay: f64,
    /// Learning offset for online learning
    pub learning_offset: f64,
    /// Maximum iterations
    pub maxiter: usize,
    /// Batch size for online learning
    pub batch_size: usize,
    /// Mean change tolerance for convergence
    pub mean_change_tol: f64,
    /// Maximum iterations for document E-step
    pub max_doc_update_iter: usize,
    /// Random seed
    pub random_seed: Option<u64>,
}

impl Default for LdaConfig {
    fn default() -> Self {
        Self {
            ntopics: 10,
            doc_topic_prior: None,  // Will be set to 1/ntopics
            topic_word_prior: None, // Will be set to 1/ntopics
            learning_method: LdaLearningMethod::Batch,
            learning_decay: 0.7,
            learning_offset: 10.0,
            maxiter: 10,
            batch_size: 128,
            mean_change_tol: 1e-3,
            max_doc_update_iter: 100,
            random_seed: None,
        }
    }
}

/// Topic representation
#[derive(Debug, Clone)]
pub struct Topic {
    /// Topic ID
    pub id: usize,
    /// Top words in the topic with their weights
    pub top_words: Vec<(String, f64)>,
    /// Topic coherence score (if computed)
    pub coherence: Option<f64>,
}

/// Latent Dirichlet Allocation
pub struct LatentDirichletAllocation {
    config: LdaConfig,
    /// Topic-word distribution (learned parameters)
    components: Option<Array2<f64>>,
    /// exp(E[log(beta)]) for efficient computation
    exp_dirichlet_component: Option<Array2<f64>>,
    /// Vocabulary mapping
    #[allow(dead_code)]
    vocabulary: Option<HashMap<usize, String>>,
    /// Number of documents seen
    n_documents: usize,
    /// Number of iterations performed
    n_iter: usize,
    /// Final perplexity bound
    #[allow(dead_code)]
    bound: Option<Vec<f64>>,
}

impl LatentDirichletAllocation {
    /// Create a new LDA model with the given configuration
    pub fn new(config: LdaConfig) -> Self {
        Self {
            config,
            components: None,
            exp_dirichlet_component: None,
            vocabulary: None,
            n_documents: 0,
            n_iter: 0,
            bound: None,
        }
    }

    /// Create a new LDA model with default configuration
    pub fn with_ntopics(ntopics: usize) -> Self {
        let config = LdaConfig {
            ntopics,
            ..Default::default()
        };
        Self::new(config)
    }

    /// Fit the LDA model on a document-term matrix
    pub fn fit(&mut self, doc_term_matrix: &Array2<f64>) -> Result<&mut Self> {
        if doc_term_matrix.nrows() == 0 || doc_term_matrix.ncols() == 0 {
            return Err(TextError::InvalidInput(
                "Document-term matrix cannot be empty".to_string(),
            ));
        }

        let n_samples = doc_term_matrix.nrows();
        let n_features = doc_term_matrix.ncols();

        // Set default priors if not provided
        let doc_topic_prior = self
            .config
            .doc_topic_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);
        let topic_word_prior = self
            .config
            .topic_word_prior
            .unwrap_or(1.0 / self.config.ntopics as f64);

        // Initialize topic-word distribution randomly
        let mut rng = self.create_rng();
        self.components = Some(self.initialize_components(n_features, &mut rng));

        // Perform training based on learning method
        match self.config.learning_method {
            LdaLearningMethod::Batch => {
                self.fit_batch(doc_term_matrix, doc_topic_prior, topic_word_prior)?;
            }
            LdaLearningMethod::Online => {
                self.fit_online(doc_term_matrix, doc_topic_prior, topic_word_prior)?;
            }
        }

        self.n_documents = n_samples;
        Ok(self)
    }

    /// Transform documents to topic distribution
    pub fn transform(&self, doc_term_matrix: &Array2<f64>) -> Result<Array2<f64>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        // Initialize document-topic distribution
        let mut doc_topic_distr = Array2::zeros((n_samples, ntopics));

        // Get exp(E[log(beta)])
        let exp_dirichlet_component = self.get_exp_dirichlet_component()?;

        // Set default prior
        let doc_topic_prior = self.config.doc_topic_prior.unwrap_or(1.0 / ntopics as f64);

        // Update document-topic distribution for each document
        for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
            let mut gamma = Array1::from_elem(ntopics, doc_topic_prior);
            self.update_doc_distribution(
                &doc.to_owned(),
                &mut gamma,
                exp_dirichlet_component,
                doc_topic_prior,
            )?;

            // Normalize to get probability distribution
            let gamma_sum = gamma.sum();
            if gamma_sum > 0.0 {
                gamma /= gamma_sum;
            }

            doc_topic_distr.row_mut(doc_idx).assign(&gamma);
        }

        Ok(doc_topic_distr)
    }

    /// Fit and transform in one step
    pub fn fit_transform(&mut self, doc_term_matrix: &Array2<f64>) -> Result<Array2<f64>> {
        self.fit(doc_term_matrix)?;
        self.transform(doc_term_matrix)
    }

    /// Get the topics with top words
    pub fn get_topics(
        &self,
        n_top_words: usize,
        vocabulary: &HashMap<usize, String>,
    ) -> Result<Vec<Topic>> {
        if self.components.is_none() {
            return Err(TextError::ModelNotFitted(
                "LDA model not fitted yet".to_string(),
            ));
        }

        let components = self.components.as_ref().unwrap();
        let mut topics = Vec::new();

        for (topic_idx, topic_dist) in components.axis_iter(Axis(0)).enumerate() {
            // Get indices of top words
            let mut word_scores: Vec<(usize, f64)> = topic_dist
                .iter()
                .enumerate()
                .map(|(idx, &score)| (idx, score))
                .collect();

            word_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

            // Get top words with their scores
            let top_words: Vec<(String, f64)> = word_scores
                .into_iter()
                .take(n_top_words)
                .filter_map(|(idx, score)| vocabulary.get(&idx).map(|word| (word.clone(), score)))
                .collect();

            topics.push(Topic {
                id: topic_idx,
                top_words,
                coherence: None,
            });
        }

        Ok(topics)
    }

    /// Get the topic-word distribution matrix
    pub fn get_topic_word_distribution(&self) -> Option<&Array2<f64>> {
        self.components.as_ref()
    }

    // Helper functions

    fn create_rng(&self) -> scirs2_core::random::rngs::StdRng {
        use scirs2_core::random::SeedableRng;
        match self.config.random_seed {
            Some(seed) => scirs2_core::random::rngs::StdRng::seed_from_u64(seed),
            None => {
                let mut temp_rng = scirs2_core::random::rng();
                scirs2_core::random::rngs::StdRng::from_rng(&mut temp_rng)
            }
        }
    }

    fn initialize_components(
        &self,
        n_features: usize,
        rng: &mut scirs2_core::random::rngs::StdRng,
    ) -> Array2<f64> {
        // Use the RNG directly

        let mut components = Array2::zeros((self.config.ntopics, n_features));
        for mut row in components.axis_iter_mut(Axis(0)) {
            for val in row.iter_mut() {
                *val = rng.random_range(0.0..1.0);
            }
            // Normalize each topic
            let row_sum: f64 = row.sum();
            if row_sum > 0.0 {
                row /= row_sum;
            }
        }

        components
    }

    fn get_exp_dirichlet_component(&self) -> Result<&Array2<f64>> {
        if self.exp_dirichlet_component.is_none() {
            return Err(TextError::ModelNotFitted(
                "Components not initialized".to_string(),
            ));
        }
        Ok(self.exp_dirichlet_component.as_ref().unwrap())
    }

    fn fit_batch(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let n_samples = doc_term_matrix.nrows();
        let ntopics = self.config.ntopics;

        // Initialize document-topic distribution
        let mut doc_topic_distr = Array2::from_elem((n_samples, ntopics), doc_topic_prior);

        // Training loop
        for iter in 0..self.config.maxiter {
            // Update exp(E[log(beta)])
            self.update_exp_dirichlet_component()?;

            // E-step: Update document-topic distribution
            let mut mean_change = 0.0;
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let mut gamma = doc_topic_distr.row(doc_idx).to_owned();
                let old_gamma = gamma.clone();

                self.update_doc_distribution(
                    &doc.to_owned(),
                    &mut gamma,
                    self.get_exp_dirichlet_component()?,
                    doc_topic_prior,
                )?;

                // Calculate mean change
                let change: f64 = (&gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
                mean_change += change / ntopics as f64;

                doc_topic_distr.row_mut(doc_idx).assign(&gamma);
            }
            mean_change /= n_samples as f64;

            // M-step: Update topic-word distribution
            self.update_topic_distribution(doc_term_matrix, &doc_topic_distr, topic_word_prior)?;

            // Check convergence
            if mean_change < self.config.mean_change_tol {
                break;
            }

            self.n_iter = iter + 1;
        }

        Ok(())
    }

    fn fit_online(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_prior: f64,
        topic_word_prior: f64,
    ) -> Result<()> {
        let (n_samples, n_features) = doc_term_matrix.dim();
        self.vocabulary
            .get_or_insert_with(|| (0..n_features).map(|i| (i, format!("word_{i}"))).collect());
        self.bound.get_or_insert_with(Vec::new);

        // Initialize topic-word distribution if not already done
        if self.components.is_none() {
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };

            let mut components = Array2::<f64>::zeros((self.config.ntopics, n_features));
            for i in 0..self.config.ntopics {
                for j in 0..n_features {
                    components[[i, j]] = rng.random::<f64>() + topic_word_prior;
                }
            }
            self.components = Some(components);
        }

        let batch_size = self.config.batch_size.min(n_samples);
        let n_batches = n_samples.div_ceil(batch_size);

        for epoch in 0..self.config.maxiter {
            let mut total_bound = 0.0;

            // Shuffle document indices for each epoch
            let mut doc_indices: Vec<usize> = (0..n_samples).collect();
            let mut rng = if let Some(seed) = self.config.random_seed {
                StdRng::seed_from_u64(seed + epoch as u64)
            } else {
                StdRng::from_rng(&mut scirs2_core::random::rng())
            };
            doc_indices.shuffle(&mut rng);

            for batch_idx in 0..n_batches {
                let start_idx = batch_idx * batch_size;
                let end_idx = ((batch_idx + 1) * batch_size).min(n_samples);

                // Get batch documents
                let batch_docs: Vec<usize> = doc_indices[start_idx..end_idx].to_vec();

                // E-step: Update document-topic distributions for batch
                let mut batch_gamma = Array2::<f64>::zeros((batch_docs.len(), self.config.ntopics));
                let mut batch_bound = 0.0;

                for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                    let doc = doc_term_matrix.row(doc_idx);
                    let mut gamma = Array1::<f64>::from_elem(self.config.ntopics, doc_topic_prior);

                    // Update document distribution
                    let components = self.components.as_ref().unwrap();
                    let exp_topic_word_distr = components.map(|x| x.exp());
                    self.update_doc_distribution(
                        &doc.to_owned(),
                        &mut gamma,
                        &exp_topic_word_distr,
                        doc_topic_prior,
                    )?;

                    batch_gamma.row_mut(local_idx).assign(&gamma);

                    // Compute bound contribution (simplified)
                    batch_bound += gamma.sum();
                }

                // M-step: Update topic-word distributions
                let learning_rate = self.compute_learning_rate(epoch * n_batches + batch_idx);
                self.update_topic_word_distribution(
                    &batch_docs,
                    doc_term_matrix,
                    &batch_gamma,
                    topic_word_prior,
                    learning_rate,
                    n_samples,
                )?;

                total_bound += batch_bound;
            }

            // Store bound for this epoch
            if let Some(ref mut bound) = self.bound {
                bound.push(total_bound / n_samples as f64);
            }

            // Check convergence
            if let Some(ref bound) = self.bound {
                if bound.len() > 1 {
                    let current_bound = bound[bound.len() - 1];
                    let prev_bound = bound[bound.len() - 2];
                    let change = (current_bound - prev_bound).abs();
                    if change < self.config.mean_change_tol {
                        break;
                    }
                }
            }

            self.n_iter = epoch + 1;
        }

        // Cache exp(E[log(beta)]) so that transform() also works after online training
        self.update_exp_dirichlet_component()?;

        self.n_documents = n_samples;
        Ok(())
    }

    /// Compute learning rate for online learning
    fn compute_learning_rate(&self, iteration: usize) -> f64 {
        (self.config.learning_offset + iteration as f64).powf(-self.config.learning_decay)
    }

    /// Update topic-word distributions in online learning
    fn update_topic_word_distribution(
        &mut self,
        batch_docs: &[usize],
        doc_term_matrix: &Array2<f64>,
        batch_gamma: &Array2<f64>,
        topic_word_prior: f64,
        learning_rate: f64,
        total_docs: usize,
    ) -> Result<()> {
        let batch_size = batch_docs.len();
        let n_features = doc_term_matrix.ncols();

        if let Some(ref mut components) = self.components {
            // Compute sufficient statistics for this batch
            let mut batch_stats = Array2::<f64>::zeros((self.config.ntopics, n_features));

            for (local_idx, &doc_idx) in batch_docs.iter().enumerate() {
                let doc = doc_term_matrix.row(doc_idx);
                let gamma = batch_gamma.row(local_idx);
                let gamma_sum = gamma.sum();

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            let phi = gamma[topic_idx] / gamma_sum;
                            batch_stats[[topic_idx, word_idx]] += count * phi;
                        }
                    }
                }
            }

            // Scale batch statistics to full corpus size
            let scale_factor = total_docs as f64 / batch_size as f64;
            batch_stats.mapv_inplace(|x| x * scale_factor);

            // Update components using natural gradient with learning rate
            for topic_idx in 0..self.config.ntopics {
                for word_idx in 0..n_features {
                    let old_val = components[[topic_idx, word_idx]];
                    let new_val = topic_word_prior + batch_stats[[topic_idx, word_idx]];
                    components[[topic_idx, word_idx]] =
                        (1.0 - learning_rate) * old_val + learning_rate * new_val;
                }
            }
        }

        Ok(())
    }

    fn update_doc_distribution(
        &self,
        doc: &Array1<f64>,
        gamma: &mut Array1<f64>,
        exp_topic_word_distr: &Array2<f64>,
        doc_topic_prior: f64,
    ) -> Result<()> {
        // Simple mean-field update: alternate between word-level topic responsibilities
        // and the document-topic parameters until the change in gamma is small.
        for _ in 0..self.config.max_doc_update_iter {
            let old_gamma = gamma.clone();

            // Reset gamma to the prior before accumulating expected counts
            gamma.fill(doc_topic_prior);

            // Update based on word counts and topic-word probabilities
            for (word_idx, &count) in doc.iter().enumerate() {
                if count > 0.0 {
                    // Responsibility of each topic for this word, using the previous gamma
                    // as a simplified stand-in for exp(E[log theta]).
                    let mut phi = Array1::from_shape_fn(gamma.len(), |k| {
                        old_gamma[k] * exp_topic_word_distr[[k, word_idx]]
                    });
                    let phi_sum = phi.sum();
                    if phi_sum > 0.0 {
                        phi /= phi_sum;
                        // Accumulate expected topic counts for this word
                        for k in 0..gamma.len() {
                            gamma[k] += count * phi[k];
                        }
                    }
                }
            }

            // Check convergence
            let change: f64 = (&*gamma - &old_gamma).iter().map(|&x| x.abs()).sum();
            if change < self.config.mean_change_tol {
                break;
            }
        }

        Ok(())
    }

    fn update_topic_distribution(
        &mut self,
        doc_term_matrix: &Array2<f64>,
        doc_topic_distr: &Array2<f64>,
        topic_word_prior: f64,
    ) -> Result<()> {
        if let Some(ref mut components) = self.components {
            let _n_features = doc_term_matrix.ncols();

            // Reset components
            components.fill(topic_word_prior);

            // Accumulate sufficient statistics
            for (doc_idx, doc) in doc_term_matrix.axis_iter(Axis(0)).enumerate() {
                let doc_topics = doc_topic_distr.row(doc_idx);

                for (word_idx, &count) in doc.iter().enumerate() {
                    if count > 0.0 {
                        for topic_idx in 0..self.config.ntopics {
                            components[[topic_idx, word_idx]] += count * doc_topics[topic_idx];
                        }
                    }
                }
            }

            // Normalize each topic
            for mut topic in components.axis_iter_mut(Axis(0)) {
                let topic_sum = topic.sum();
                if topic_sum > 0.0 {
                    topic /= topic_sum;
                }
            }
        }

        Ok(())
    }

    fn update_exp_dirichlet_component(&mut self) -> Result<()> {
        if let Some(ref components) = self.components {
            // For simplicity, we'll use the components directly
            // In a full implementation, this would compute exp(E[log(beta)])
            self.exp_dirichlet_component = Some(components.clone());
        }
        Ok(())
    }
}

/// Builder for creating LDA models
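///
/// For instance, a model can be configured and built like this (a minimal sketch
/// using only the builder methods defined below):
///
/// ```rust
/// use scirs2_text::topic_modeling::{LdaBuilder, LdaLearningMethod};
///
/// let lda = LdaBuilder::new()
///     .ntopics(5)
///     .doc_topic_prior(0.2)
///     .learning_method(LdaLearningMethod::Batch)
///     .maxiter(50)
///     .random_seed(7)
///     .build();
/// ```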
pub struct LdaBuilder {
    config: LdaConfig,
}

impl LdaBuilder {
    /// Create a new builder with default configuration
    pub fn new() -> Self {
        Self {
            config: LdaConfig::default(),
        }
    }

    /// Set the number of topics
    pub fn ntopics(mut self, ntopics: usize) -> Self {
        self.config.ntopics = ntopics;
        self
    }

    /// Set the document-topic prior (alpha)
    pub fn doc_topic_prior(mut self, prior: f64) -> Self {
        self.config.doc_topic_prior = Some(prior);
        self
    }

    /// Set the topic-word prior (eta)
    pub fn topic_word_prior(mut self, prior: f64) -> Self {
        self.config.topic_word_prior = Some(prior);
        self
    }

    /// Set the learning method
    pub fn learning_method(mut self, method: LdaLearningMethod) -> Self {
        self.config.learning_method = method;
        self
    }

    /// Set the maximum iterations
    pub fn maxiter(mut self, maxiter: usize) -> Self {
        self.config.maxiter = maxiter;
        self
    }

    /// Set the random seed
    pub fn random_seed(mut self, seed: u64) -> Self {
        self.config.random_seed = Some(seed);
        self
    }

    /// Build the LDA model
    pub fn build(self) -> LatentDirichletAllocation {
        LatentDirichletAllocation::new(self.config)
    }
}

impl Default for LdaBuilder {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lda_creation() {
        let lda = LatentDirichletAllocation::with_ntopics(5);
        assert_eq!(lda.config.ntopics, 5);
    }

    #[test]
    fn test_lda_builder() {
        let lda = LdaBuilder::new()
            .ntopics(10)
            .doc_topic_prior(0.1)
            .maxiter(20)
            .random_seed(42)
            .build();

        assert_eq!(lda.config.ntopics, 10);
        assert_eq!(lda.config.doc_topic_prior, Some(0.1));
        assert_eq!(lda.config.maxiter, 20);
        assert_eq!(lda.config.random_seed, Some(42));
    }

    #[test]
    fn test_lda_fit_transform() {
        // Create a simple document-term matrix
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 6),
            vec![
                1.0, 1.0, 0.0, 0.0, 0.0, 0.0, // Doc 1
                0.0, 1.0, 1.0, 0.0, 0.0, 0.0, // Doc 2
                0.0, 0.0, 0.0, 1.0, 1.0, 0.0, // Doc 3
                0.0, 0.0, 0.0, 0.0, 1.0, 1.0, // Doc 4
            ],
        )
        .unwrap();

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        let doc_topics = lda.fit_transform(&doc_term_matrix).unwrap();

        assert_eq!(doc_topics.nrows(), 4);
        assert_eq!(doc_topics.ncols(), 2);

        // Check that each document's topic distribution sums to 1
        for row in doc_topics.axis_iter(Axis(0)) {
            let sum: f64 = row.sum();
            assert!((sum - 1.0).abs() < 1e-6);
        }
    }

    #[test]
    fn test_get_topics() {
        let doc_term_matrix = Array2::from_shape_vec(
            (4, 3),
            vec![2.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0, 2.0, 2.0, 1.0, 1.0],
        )
        .unwrap();

        let mut vocabulary = HashMap::new();
        vocabulary.insert(0, "word1".to_string());
        vocabulary.insert(1, "word2".to_string());
        vocabulary.insert(2, "word3".to_string());

        let mut lda = LatentDirichletAllocation::with_ntopics(2);
        lda.fit(&doc_term_matrix).unwrap();

        let topics = lda.get_topics(3, &vocabulary).unwrap();
        assert_eq!(topics.len(), 2);

        for topic in &topics {
            assert_eq!(topic.top_words.len(), 3);
        }
    }
}