scirs2_text/hdp.rs

//! Hierarchical Dirichlet Process (HDP) topic model.
//!
//! HDP automatically infers the number of topics from data. This
//! implementation uses a truncated approximation of the HDP, fitted with
//! collapsed Gibbs sampling.
//!
//! ## Model
//!
//! The generative process mirrors LDA but replaces the fixed-`K` Dirichlet
//! prior on topic proportions with a DP prior, allowing the number of active
//! topics to grow with data.
//!
//! Following Teh et al. (2006), we approximate the HDP with a fixed
//! truncation level `T` and fit it by collapsed Gibbs sampling over direct
//! topic assignments. A topic is considered "active" when at least one
//! token has been assigned to it.
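//!
//! Each Gibbs sweep reassigns token `w` of document `d` to topic `k` with
//! probability proportional to (mirroring the update in [`HdpModel::fit`]):
//!
//! ```text
//! p(z = k | rest) ∝ (n_dk + α/T) · (n_kw + η) / (n_k + V·η)
//! ```
//!
//! where `n_dk` is the number of tokens in document `d` assigned to topic
//! `k`, `n_kw` the number of times word `w` is assigned to topic `k`, `n_k`
//! the total token count of topic `k`, and `V` the vocabulary size.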
//!
//! ## References
//!
//! Teh, Y. W., Jordan, M. I., Beal, M. J., & Blei, D. M. (2006).
//! "Hierarchical Dirichlet Processes." *JASA*, 101(476), 1566–1581.
//! <https://doi.org/10.1198/016214506000000302>

use crate::error::{Result, TextError};
use scirs2_core::random::prelude::*;
use scirs2_core::random::{rngs::StdRng, SeedableRng};

// ── HdpConfig ─────────────────────────────────────────────────────────────────

/// Configuration for the [`HdpModel`].
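///
/// # Examples
///
/// Override selected fields and keep the rest at their defaults (the same
/// pattern used by the tests below):
///
/// ```ignore
/// let config = HdpConfig { max_topics: 50, seed: Some(42), ..Default::default() };
/// ```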
#[derive(Debug, Clone)]
pub struct HdpConfig {
    /// Truncation level T — maximum number of topics the model can infer.
    /// Default: 20.
    pub max_topics: usize,
    /// Document-level DP concentration parameter α. Default: 1.0.
    pub alpha: f64,
    /// Corpus-level DP concentration parameter γ. Default: 1.0.
    /// Note: not used by the current truncated Gibbs sampler.
    pub gamma: f64,
    /// Symmetric Dirichlet prior on word distributions η. Default: 0.1.
    pub eta: f64,
    /// Number of collapsed Gibbs iterations. Default: 100.
    pub n_iter: usize,
    /// Optional random seed for reproducibility.
    pub seed: Option<u64>,
}

impl Default for HdpConfig {
    fn default() -> Self {
        HdpConfig {
            max_topics: 20,
            alpha: 1.0,
            gamma: 1.0,
            eta: 0.1,
            n_iter: 100,
            seed: None,
        }
    }
}

// ── HdpResult ─────────────────────────────────────────────────────────────────

/// Summary statistics returned by [`HdpModel::fit`].
#[derive(Debug, Clone)]
pub struct HdpResult {
    /// Number of active topics (those with at least one assigned token).
    pub n_topics: usize,
    /// Per-token perplexity on the training corpus, exp(-avg log-likelihood).
    pub perplexity: f64,
    /// Average log-likelihood per token.
    pub log_likelihood: f64,
    /// Actual number of Gibbs iterations performed.
    pub iterations: usize,
}

// ── HdpModel ──────────────────────────────────────────────────────────────────

/// Hierarchical Dirichlet Process topic model.
///
/// Fitted with collapsed Gibbs sampling under a fixed truncation level `T`
/// (see the module docs).
///
/// After calling [`fit`](HdpModel::fit) you can:
/// - Query [`n_topics_active`](HdpModel::n_topics_active) for the inferred
///   number of topics.
/// - Call [`transform`](HdpModel::transform) to get topic distributions for
///   new documents.
/// - Call [`top_words`](HdpModel::top_words) to inspect topic–word associations.
/// - Call [`coherence`](HdpModel::coherence) to evaluate topic quality.
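///
/// # Examples
///
/// A minimal end-to-end sketch (the import path assumes this module is
/// exposed as `scirs2_text::hdp`):
///
/// ```ignore
/// use scirs2_text::hdp::{HdpConfig, HdpModel};
///
/// // Two tiny documents over a vocabulary of 5 word indices.
/// let corpus = vec![vec![0usize, 1, 2, 0, 1], vec![3, 4, 3, 4, 4]];
///
/// let config = HdpConfig { max_topics: 10, n_iter: 50, seed: Some(42), ..Default::default() };
/// let mut model = HdpModel::new(config);
/// let result = model.fit(&corpus, 5).expect("fit should succeed");
/// println!("active topics: {}", result.n_topics);
///
/// // Topic mixture for a new document (length == max_topics).
/// let theta = model.transform(&[0, 1, 2]).expect("transform should succeed");
/// assert_eq!(theta.len(), 10);
/// ```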
pub struct HdpModel {
    config: HdpConfig,
    /// φ[k][w]: normalised topic-word distributions (T × vocab_size).
    phi: Vec<Vec<f64>>,
    /// Unnormalised topic-word count matrix (T × vocab_size).
    topic_word_counts: Vec<Vec<f64>>,
    /// Number of tokens assigned to each topic.
    topic_counts: Vec<usize>,
    /// Number of active topics inferred from data.
    pub n_topics_active: usize,
    /// Vocabulary size.
    vocab_size: usize,
    /// Whether the model has been fitted.
    is_fitted: bool,
}

impl HdpModel {
    /// Create a new (unfitted) HDP model.
    pub fn new(config: HdpConfig) -> Self {
        let t = config.max_topics;
        HdpModel {
            config,
            phi: vec![vec![]; t],
            topic_word_counts: vec![vec![]; t],
            topic_counts: vec![0; t],
            n_topics_active: 0,
            vocab_size: 0,
            is_fitted: false,
        }
    }

    // ── fit ──────────────────────────────────────────────────────────────

    /// Fit the HDP model to `corpus`.
    ///
    /// `corpus` is a slice of documents; each document is a `Vec<usize>`
    /// of word indices (all indices must be < `vocab_size`).
    ///
    /// # Errors
    /// Returns an error if `corpus` is empty, `vocab_size` is zero, or
    /// any word index is out of range.
    pub fn fit(&mut self, corpus: &[Vec<usize>], vocab_size: usize) -> Result<HdpResult> {
        if corpus.is_empty() {
            return Err(TextError::InvalidInput(
                "corpus must not be empty".to_string(),
            ));
        }
        if vocab_size == 0 {
            return Err(TextError::InvalidInput(
                "vocab_size must be > 0".to_string(),
            ));
        }
        // Validate indices
        for (di, doc) in corpus.iter().enumerate() {
            for &w in doc {
                if w >= vocab_size {
                    return Err(TextError::InvalidInput(format!(
                        "word index {w} in document {di} exceeds vocab_size {vocab_size}"
                    )));
                }
            }
        }

        self.vocab_size = vocab_size;
        let t = self.config.max_topics;

        // Initialise count tables
        self.topic_word_counts = vec![vec![0.0f64; vocab_size]; t];
        self.topic_counts = vec![0usize; t];

        let mut rng = self.make_rng();

        // ── Initialise topic assignments randomly ─────────────────────
        let n_docs = corpus.len();
        // z[d][n] = topic assignment for token n in doc d
        let mut z: Vec<Vec<usize>> = corpus
            .iter()
            .map(|doc| {
                doc.iter()
                    .map(|_| rng.random_range(0..t))
                    .collect::<Vec<usize>>()
            })
            .collect();

        // doc-level topic counts: theta_counts[d][k] = #tokens in doc d assigned to topic k
        let mut theta_counts: Vec<Vec<usize>> = vec![vec![0usize; t]; n_docs];

        // Populate initial counts
        for (d, doc) in corpus.iter().enumerate() {
            for (n, &w) in doc.iter().enumerate() {
                let k = z[d][n];
                self.topic_word_counts[k][w] += 1.0;
                self.topic_counts[k] += 1;
                theta_counts[d][k] += 1;
            }
        }

        // ── Collapsed Gibbs sampling ──────────────────────────────────
        let alpha = self.config.alpha;
        let eta = self.config.eta;
        let eta_sum = eta * vocab_size as f64;

        let mut iter_done = 0usize;
        for _iter in 0..self.config.n_iter {
            for d in 0..n_docs {
                for n in 0..corpus[d].len() {
                    let w = corpus[d][n];
                    let k_old = z[d][n];

                    // Remove this token's contribution
                    self.topic_word_counts[k_old][w] -= 1.0;
                    self.topic_counts[k_old] -= 1;
                    theta_counts[d][k_old] -= 1;

                    // Compute unnormalised probabilities for each topic
                    let mut probs = vec![0.0f64; t];
                    for k in 0..t {
                        let doc_factor = theta_counts[d][k] as f64 + alpha / t as f64;
                        let word_factor = (self.topic_word_counts[k][w] + eta)
                            / (self.topic_counts[k] as f64 + eta_sum);
                        probs[k] = doc_factor * word_factor;
                    }

                    // Sample new topic from normalised distribution
                    let k_new = sample_categorical(&probs, &mut rng);

                    // Update counts
                    z[d][n] = k_new;
                    self.topic_word_counts[k_new][w] += 1.0;
                    self.topic_counts[k_new] += 1;
                    theta_counts[d][k_new] += 1;
                }
            }
            iter_done += 1;
        }

        // ── Compute normalised φ and active topics ────────────────────
        self.phi = (0..t)
            .map(|k| {
                let total = self.topic_counts[k] as f64 + eta_sum;
                (0..vocab_size)
                    .map(|w| (self.topic_word_counts[k][w] + eta) / total)
                    .collect()
            })
            .collect();

        self.n_topics_active = self.topic_counts.iter().filter(|&&c| c > 0).count();
        self.is_fitted = true;

        // ── Compute log-likelihood / perplexity ───────────────────────
        let (ll, pp) = self.compute_perplexity(corpus, &theta_counts, eta, eta_sum);

        Ok(HdpResult {
            n_topics: self.n_topics_active,
            perplexity: pp,
            log_likelihood: ll,
            iterations: iter_done,
        })
    }

    // ── transform ────────────────────────────────────────────────────────

    /// Infer the topic distribution for a new (unseen) document.
    ///
    /// Performs a single accumulation pass over the document without
    /// modifying model parameters; word indices `>= vocab_size` are skipped.
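    ///
    /// Starting from `θ_k = α/T`, each in-vocabulary word `w` adds its
    /// normalised posterior to `θ`:
    ///
    /// ```text
    /// θ_k += p(k | w),  where  p(k | w) ∝ θ_k · (n_kw + η) / (n_k + V·η)
    /// ```
    ///
    /// The final `θ` is then normalised to sum to one.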
    ///
    /// # Errors
    /// Returns an error if the model is not fitted or `doc` is empty.
    pub fn transform(&self, doc: &[usize]) -> Result<Vec<f64>> {
        if !self.is_fitted {
            return Err(TextError::ModelNotFitted(
                "HDP model not fitted yet".to_string(),
            ));
        }
        if doc.is_empty() {
            return Err(TextError::InvalidInput(
                "document must not be empty".to_string(),
            ));
        }

        let t = self.config.max_topics;
        let eta = self.config.eta;
        let eta_sum = eta * self.vocab_size as f64;

        let mut theta = vec![self.config.alpha / t as f64; t];

        // Simple E-step: accumulate normalised word-topic probabilities
        for &w in doc {
            if w >= self.vocab_size {
                continue;
            }
            let mut word_probs: Vec<f64> = (0..t)
                .map(|k| {
                    theta[k] * (self.topic_word_counts[k][w] + eta)
                        / (self.topic_counts[k] as f64 + eta_sum)
                })
                .collect();
            let sum: f64 = word_probs.iter().sum();
            if sum > 0.0 {
                word_probs.iter_mut().for_each(|p| *p /= sum);
                for k in 0..t {
                    theta[k] += word_probs[k];
                }
            }
        }

        // Normalise θ
        let theta_sum: f64 = theta.iter().sum();
        if theta_sum > 0.0 {
            theta.iter_mut().for_each(|p| *p /= theta_sum);
        }

        Ok(theta)
    }

    // ── top_words ────────────────────────────────────────────────────────

    /// Return the top-`n` word indices for each *active* topic, ordered by
    /// descending probability.
    ///
    /// # Errors
    /// Returns an error if the model is not fitted.
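    ///
    /// # Examples
    ///
    /// A usage sketch; `vocab` is a hypothetical caller-maintained
    /// index→word table:
    ///
    /// ```ignore
    /// let vocab = ["ball", "game", "team", "pizza", "pasta"];
    /// for (k, words) in model.top_words(3).expect("model is fitted").iter().enumerate() {
    ///     let terms: Vec<&str> = words.iter().map(|&w| vocab[w]).collect();
    ///     println!("topic {k}: {}", terms.join(", "));
    /// }
    /// ```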
    pub fn top_words(&self, n: usize) -> Result<Vec<Vec<usize>>> {
        if !self.is_fitted {
            return Err(TextError::ModelNotFitted(
                "HDP model not fitted yet".to_string(),
            ));
        }

        let t = self.config.max_topics;
        let mut result = Vec::new();

        for k in 0..t {
            if self.topic_counts[k] == 0 {
                continue; // skip inactive topics
            }
            let phi_k = &self.phi[k];
            let mut indices: Vec<usize> = (0..phi_k.len()).collect();
            // Sort by descending probability
            indices.sort_by(|&a, &b| {
                phi_k[b]
                    .partial_cmp(&phi_k[a])
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            indices.truncate(n);
            result.push(indices);
        }

        Ok(result)
    }

    // ── coherence ────────────────────────────────────────────────────────

    /// Compute per-topic PMI-based coherence scores.
    ///
    /// Uses the top-`n_top` words per active topic and measures co-occurrence
    /// in the training corpus.
    ///
    /// # Errors
    /// Returns an error if the model is not fitted.
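    ///
    /// For each pair of top words `(wi, wj)` with document frequencies
    /// `df_i`, `df_j` and co-document frequency `co` over `N` documents
    /// (all Laplace-smoothed by +1), the per-topic score averages
    ///
    /// ```text
    /// PMI(wi, wj) = ln((co + 1) / N) − ln((df_i + 1) / N) − ln((df_j + 1) / N)
    /// ```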
    pub fn coherence(&self, corpus: &[Vec<usize>], n_top: usize) -> Result<Vec<f64>> {
        if !self.is_fitted {
            return Err(TextError::ModelNotFitted(
                "HDP model not fitted yet".to_string(),
            ));
        }

        let top = self.top_words(n_top)?;
        let n_docs = corpus.len() as f64;

        // Build word document-frequency and co-document-frequency tables.
        // Note: the dense V×V table means memory grows quadratically with
        // vocabulary size.
        let mut df: Vec<f64> = vec![0.0; self.vocab_size];
        let mut codf: Vec<Vec<f64>> = vec![vec![0.0; self.vocab_size]; self.vocab_size];
        for doc in corpus {
            // Deduplicate within doc for DF counting
            let mut seen = std::collections::HashSet::new();
            for &w in doc {
                if w < self.vocab_size && seen.insert(w) {
                    df[w] += 1.0;
                }
            }
            let seen_vec: Vec<usize> = seen.into_iter().collect();
            for (i, &wi) in seen_vec.iter().enumerate() {
                for &wj in &seen_vec[i + 1..] {
                    let (a, b) = if wi < wj { (wi, wj) } else { (wj, wi) };
                    codf[a][b] += 1.0;
                }
            }
        }

        let mut scores = Vec::with_capacity(top.len());
        for topic_words in &top {
            let mut sum = 0.0f64;
            let mut count = 0usize;
            for (i, &wi) in topic_words.iter().enumerate() {
                for &wj in &topic_words[i + 1..] {
                    let (a, b) = if wi < wj { (wi, wj) } else { (wj, wi) };
                    let co = codf[a][b] + 1.0; // Laplace smoothing
                    let di = df[wi] + 1.0;
                    let dj = df[wj] + 1.0;
                    // PMI = log(P(wi,wj)) - log(P(wi)) - log(P(wj))
                    let pmi = (co / n_docs).ln() - (di / n_docs).ln() - (dj / n_docs).ln();
                    sum += pmi;
                    count += 1;
                }
            }
            scores.push(if count > 0 { sum / count as f64 } else { 0.0 });
        }

        Ok(scores)
    }

    // ── Internal helpers ──────────────────────────────────────────────────

    fn make_rng(&self) -> StdRng {
        match self.config.seed {
            Some(s) => StdRng::seed_from_u64(s),
            None => StdRng::from_rng(&mut scirs2_core::random::rng()),
        }
    }

    /// Compute average log-likelihood and perplexity after Gibbs.
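    ///
    /// With `N` the total token count:
    ///
    /// ```text
    /// avg_ll     = (1/N) · Σ_{d,w} ln p(w | d),  where  p(w | d) = Σ_k θ_dk · φ_kw
    /// perplexity = exp(−avg_ll)
    /// ```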
    fn compute_perplexity(
        &self,
        corpus: &[Vec<usize>],
        theta_counts: &[Vec<usize>],
        eta: f64,
        eta_sum: f64,
    ) -> (f64, f64) {
        let t = self.config.max_topics;
        let alpha = self.config.alpha;
        let mut total_ll = 0.0f64;
        let mut total_tokens = 0usize;

        for (d, doc) in corpus.iter().enumerate() {
            let theta_sum: f64 = theta_counts[d].iter().sum::<usize>() as f64 + alpha;
            for &w in doc {
                if w >= self.vocab_size {
                    continue;
                }
                // p(w | doc) = Σ_k θ_{dk} φ_{kw}
                let p_w: f64 = (0..t)
                    .map(|k| {
                        let theta_dk = (theta_counts[d][k] as f64 + alpha / t as f64) / theta_sum;
                        let phi_kw = (self.topic_word_counts[k][w] + eta)
                            / (self.topic_counts[k] as f64 + eta_sum);
                        theta_dk * phi_kw
                    })
                    .sum();
                if p_w > 0.0 {
                    total_ll += p_w.ln();
                }
                total_tokens += 1;
            }
        }

        if total_tokens == 0 {
            return (0.0, 1.0);
        }

        let avg_ll = total_ll / total_tokens as f64;
        let perplexity = (-avg_ll).exp();
        (avg_ll, perplexity)
    }
}

impl std::fmt::Debug for HdpModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HdpModel")
            .field("max_topics", &self.config.max_topics)
            .field("n_topics_active", &self.n_topics_active)
            .field("vocab_size", &self.vocab_size)
            .field("is_fitted", &self.is_fitted)
            .finish()
    }
}

// ── Internal sampling helper ──────────────────────────────────────────────────

/// Sample a category index from an unnormalised probability vector.
fn sample_categorical(probs: &[f64], rng: &mut StdRng) -> usize {
    let total: f64 = probs.iter().sum();
    if total <= 0.0 {
        // Degenerate — return uniform
        return rng.random_range(0..probs.len());
    }
    let u: f64 = rng.random_range(0.0..total);
    let mut cumulative = 0.0;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if u < cumulative {
            return i;
        }
    }
    probs.len() - 1 // fallback due to floating-point rounding
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a synthetic corpus with three clearly separated topics.
    /// Topic 0: words 0–4 (tech)
    /// Topic 1: words 5–9 (sports)
    /// Topic 2: words 10–14 (food)
    fn synthetic_corpus(n_per_topic: usize) -> Vec<Vec<usize>> {
        let mut corpus = Vec::new();
        let mut rng = StdRng::seed_from_u64(99);
        // Topic 0 documents
        for _ in 0..n_per_topic {
            let doc: Vec<usize> = (0..20).map(|_| rng.random_range(0..5)).collect();
            corpus.push(doc);
        }
        // Topic 1 documents
        for _ in 0..n_per_topic {
            let doc: Vec<usize> = (0..20).map(|_| rng.random_range(5..10)).collect();
            corpus.push(doc);
        }
        // Topic 2 documents
        for _ in 0..n_per_topic {
            let doc: Vec<usize> = (0..20).map(|_| rng.random_range(10..15)).collect();
            corpus.push(doc);
        }
        corpus
    }

    // ── test_hdp_infers_topics ────────────────────────────────────────────

    #[test]
    fn test_hdp_infers_topics() {
        let corpus = synthetic_corpus(15);
        let config = HdpConfig {
            max_topics: 20,
            n_iter: 30,
            seed: Some(42),
            ..Default::default()
        };
        let mut model = HdpModel::new(config);
        let result = model.fit(&corpus, 15).expect("fit should succeed");

        assert!(
            result.n_topics <= 20,
            "active topics ({}) must be <= max_topics",
            result.n_topics
        );
        assert!(
            result.n_topics >= 1,
            "at least one topic must be active, got {}",
            result.n_topics
        );
    }

    // ── test_hdp_perplexity_finite ────────────────────────────────────────

    #[test]
    fn test_hdp_perplexity_finite() {
        let corpus = synthetic_corpus(10);
        let config = HdpConfig {
            max_topics: 10,
            n_iter: 20,
            seed: Some(7),
            ..Default::default()
        };
        let mut model = HdpModel::new(config);
        let result = model.fit(&corpus, 15).expect("fit should succeed");

        assert!(
            result.perplexity.is_finite(),
            "perplexity must be finite, got {}",
            result.perplexity
        );
        assert!(
            result.perplexity > 0.0,
            "perplexity must be positive, got {}",
            result.perplexity
        );
        assert!(
            result.log_likelihood.is_finite(),
            "log_likelihood must be finite"
        );
    }

    // ── test_hdp_top_words_valid ──────────────────────────────────────────

    #[test]
    fn test_hdp_top_words_valid() {
        let corpus = synthetic_corpus(10);
        let config = HdpConfig {
            max_topics: 10,
            n_iter: 20,
            seed: Some(1),
            ..Default::default()
        };
        let mut model = HdpModel::new(config);
        model.fit(&corpus, 15).expect("fit should succeed");

        let top5 = model.top_words(5).expect("top_words should succeed");
        for topic_words in &top5 {
            assert!(
                topic_words.len() <= 5,
                "each topic should have <= n top words"
            );
            for &w in topic_words {
                assert!(w < 15, "word index {w} must be < vocab_size 15");
            }
        }
    }

    // ── test_hdp_transform ────────────────────────────────────────────────

    #[test]
    fn test_hdp_transform() {
        let corpus = synthetic_corpus(10);
        let config = HdpConfig {
            max_topics: 5,
            n_iter: 15,
            seed: Some(123),
            ..Default::default()
        };
        let mut model = HdpModel::new(config);
        model.fit(&corpus, 15).expect("fit should succeed");

        let doc = vec![0usize, 1, 2, 3, 0];
        let theta = model.transform(&doc).expect("transform should succeed");
        assert_eq!(theta.len(), 5);
        let sum: f64 = theta.iter().sum();
        assert!((sum - 1.0).abs() < 1e-9, "topic distribution must sum to 1");
        for &p in &theta {
            assert!(p >= 0.0, "all topic probabilities must be >= 0");
        }
    }

    // ── test_hdp_coherence ────────────────────────────────────────────────

    #[test]
    fn test_hdp_coherence() {
        let corpus = synthetic_corpus(10);
        let config = HdpConfig {
            max_topics: 5,
            n_iter: 15,
            seed: Some(55),
            ..Default::default()
        };
        let mut model = HdpModel::new(config);
        model.fit(&corpus, 15).expect("fit should succeed");

        let scores = model
            .coherence(&corpus, 3)
            .expect("coherence should succeed");
        for &s in &scores {
            assert!(s.is_finite(), "coherence score must be finite, got {s}");
        }
    }

    // ── error cases ───────────────────────────────────────────────────────

    #[test]
    fn test_hdp_empty_corpus_error() {
        let mut model = HdpModel::new(HdpConfig::default());
        let result = model.fit(&[], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_hdp_zero_vocab_error() {
        let mut model = HdpModel::new(HdpConfig::default());
        let result = model.fit(&[vec![0usize, 1]], 0);
        assert!(result.is_err());
    }
}
654}