Skip to main content

scirs2_text/topic/
hdp.rs

1//! Hierarchical Dirichlet Process (HDP) topic model — automatic topic
2//! number selection via the Chinese Restaurant Franchise (CRF) analogy.
3//!
4//! Unlike LDA, which requires the number of topics K to be specified a priori,
5//! HDP places a Dirichlet Process prior on topic proportions so the number
6//! of active topics can grow with the data (up to a truncation `max_topics`).
7//!
8//! ## Algorithm
9//!
//! We use a **truncated** approximation with at most `max_topics` topics, in
//! which each document places a symmetric Dirichlet pseudo-count of `α / T` on
//! every topic, combined with **collapsed Gibbs sampling** over topic
//! assignments.  The model is motivated by the description in:
13//!
14//! > Teh, Y. W., Jordan, M. I., Beal, M. J., & Blei, D. M. (2006).
15//! > "Hierarchical Dirichlet Processes." *JASA*, 101(476), 1566–1581.
16//! > <https://doi.org/10.1198/016214506000000302>
17//!
18//! ## Error types
19//!
//! This module defines its own [`TopicError`] for self-contained use; the code
//! in this file does not depend on `crate::error::TextError`.
22
23use scirs2_core::random::prelude::*;
24use scirs2_core::random::{rngs::StdRng, SeedableRng};
25
26// ── TopicError ────────────────────────────────────────────────────────────────
27
28/// Errors that can be returned by [`Hdp`].
29#[derive(Debug, thiserror::Error)]
30pub enum TopicError {
31    /// Corpus passed to [`Hdp::fit`] contains no documents.
32    #[error("empty corpus")]
33    EmptyCorpus,
34
35    /// A word identifier exceeds the declared vocabulary size.
36    #[error("word id {0} out of vocab range {1}")]
37    WordOutOfVocab(usize, usize),
38}
39
40// ── HdpConfig ─────────────────────────────────────────────────────────────────
41
/// Configuration for [`Hdp`].
///
/// Build with [`Default`] and override individual fields with struct-update
/// syntax, e.g. `HdpConfig { n_iter: 20, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct HdpConfig {
    /// Document-level concentration α.
    ///
    /// Every document receives a symmetric pseudo-count of `α / max_topics`
    /// per topic, so larger values spread a document's tokens over more
    /// topics.  Default: 1.0.
    pub alpha: f64,

    /// Corpus-level (top-level) DP concentration γ.
    ///
    /// Kept for API completeness; the current Gibbs sampler does not read it.
    /// Default: 1.0.
    pub gamma: f64,

    /// Symmetric Dirichlet word prior η: the pseudo-count added to every
    /// topic-word count when smoothing topic distributions.  Default: 0.1.
    pub eta: f64,

    /// Number of full Gibbs sweeps over the corpus.  Default: 100.
    pub n_iter: usize,

    /// Truncation level T: the maximum number of topics the model can
    /// represent.  Must be at least 1 for corpora that contain any tokens.
    /// Default: 20.
    pub max_topics: usize,

    /// Seed for the internal RNG.  The same seed and inputs always produce
    /// the same fit.  Default: 42.
    pub seed: u64,
}

impl Default for HdpConfig {
    fn default() -> Self {
        HdpConfig {
            alpha: 1.0,
            gamma: 1.0,
            eta: 0.1,
            n_iter: 100,
            max_topics: 20,
            seed: 42,
        }
    }
}
83
84// ── HdpState ──────────────────────────────────────────────────────────────────
85
/// Count tables and per-token topic assignments maintained by [`Hdp`]'s
/// collapsed Gibbs sampler.
///
/// Topic ids range over `0..max_topics`.
#[derive(Debug, Clone)]
pub struct HdpState {
    /// How many topics own at least one token.  It is `0` until
    /// [`Hdp::fit`] completes, after which it is recomputed from the final
    /// counts.
    pub n_topics: usize,
    /// Topic-by-word count matrix: entry `[k][w]` is the number of tokens of
    /// word `w` currently assigned to topic `k`.
    pub topic_word_counts: Vec<Vec<usize>>,
    /// Document-by-topic count matrix: entry `[d][k]` is the number of tokens
    /// of document `d` currently assigned to topic `k`.
    pub doc_topic_counts: Vec<Vec<usize>>,
    /// Topic of every token: entry `[d][pos]` is the topic of the token at
    /// position `pos` of document `d`.
    pub word_assignments: Vec<Vec<usize>>,
}
99
100// ── Hdp ───────────────────────────────────────────────────────────────────────
101
102/// Hierarchical Dirichlet Process topic model.
103///
104/// Call [`fit`](Hdp::fit) to perform Gibbs sampling, then query:
105/// - [`active_topics`](Hdp::active_topics) — number of topics with ≥1 token.
106/// - [`topic_distribution`](Hdp::topic_distribution) — topic-word probabilities.
107/// - [`document_distribution`](Hdp::document_distribution) — per-document
108///   topic proportions.
109/// - [`perplexity`](Hdp::perplexity) — held-in per-token perplexity estimate.
110/// - [`top_words`](Hdp::top_words) — most probable word indices per topic.
111pub struct Hdp {
112    config: HdpConfig,
113    state: HdpState,
114    vocab_size: usize,
115    n_docs: usize,
116    /// Corpus kept for perplexity / document_distribution after fit.
117    corpus: Vec<Vec<usize>>,
118    fitted: bool,
119}
120
121impl Hdp {
122    /// Construct an unfitted model.
123    pub fn new(config: HdpConfig, n_docs: usize, vocab_size: usize) -> Self {
124        let t = config.max_topics;
125        Hdp {
126            config,
127            state: HdpState {
128                n_topics: 0,
129                topic_word_counts: vec![vec![0; vocab_size]; t],
130                doc_topic_counts: vec![vec![0; t]; n_docs],
131                word_assignments: Vec::new(),
132            },
133            vocab_size,
134            n_docs,
135            corpus: Vec::new(),
136            fitted: false,
137        }
138    }
139
140    // ── fit ──────────────────────────────────────────────────────────────────
141
142    /// Fit the HDP model to `corpus` using collapsed Gibbs sampling.
143    ///
144    /// `corpus[d]` is a sequence of word indices (all must be < `vocab_size`).
145    ///
146    /// # Errors
147    ///
148    /// Returns [`TopicError::EmptyCorpus`] when `corpus` is empty and
149    /// [`TopicError::WordOutOfVocab`] when any index exceeds `vocab_size`.
150    pub fn fit(&mut self, corpus: &[Vec<usize>]) -> Result<(), TopicError> {
151        if corpus.is_empty() {
152            return Err(TopicError::EmptyCorpus);
153        }
154
155        for doc in corpus {
156            for &w in doc {
157                if w >= self.vocab_size {
158                    return Err(TopicError::WordOutOfVocab(w, self.vocab_size));
159                }
160            }
161        }
162
163        self.corpus = corpus.to_vec();
164        self.n_docs = corpus.len();
165
166        let t = self.config.max_topics;
167        let voc = self.vocab_size;
168
169        // Re-initialise count tables with correct sizes
170        self.state.topic_word_counts = vec![vec![0usize; voc]; t];
171        self.state.doc_topic_counts = vec![vec![0usize; t]; self.n_docs];
172        self.state.word_assignments = corpus.iter().map(|doc| vec![0usize; doc.len()]).collect();
173
174        let mut rng = StdRng::seed_from_u64(self.config.seed);
175
176        // Random initialisation
177        for (d, doc) in corpus.iter().enumerate() {
178            for (n, &w) in doc.iter().enumerate() {
179                let k = rng.random_range(0..t);
180                self.state.word_assignments[d][n] = k;
181                self.state.topic_word_counts[k][w] += 1;
182                self.state.doc_topic_counts[d][k] += 1;
183            }
184        }
185
186        let alpha = self.config.alpha;
187        let gamma = self.config.gamma;
188
189        // Collapsed Gibbs sampling
190        for _iter in 0..self.config.n_iter {
191            for d in 0..self.n_docs {
192                for n in 0..corpus[d].len() {
193                    let w = corpus[d][n];
194                    hdp_gibbs_sample(
195                        &mut self.state,
196                        d,
197                        n,
198                        w,
199                        alpha,
200                        gamma,
201                        self.vocab_size,
202                        &mut rng,
203                    );
204                }
205            }
206        }
207
208        // Count active topics
209        let topic_totals: Vec<usize> = (0..t)
210            .map(|k| self.state.topic_word_counts[k].iter().sum())
211            .collect();
212        self.state.n_topics = topic_totals.iter().filter(|&&c| c > 0).count();
213        self.fitted = true;
214
215        Ok(())
216    }
217
218    // ── topic_distribution ────────────────────────────────────────────────────
219
220    /// Return the normalised topic-word distribution for topic `k`.
221    ///
222    /// The result is a `Vec<f64>` of length `vocab_size` that sums to 1.
223    /// Smoothed by the Dirichlet word prior η.
224    ///
225    /// # Panics
226    /// Panics when `topic >= max_topics` (out-of-bounds).
227    pub fn topic_distribution(&self, topic: usize) -> Vec<f64> {
228        let eta = self.config.eta;
229        let eta_sum = eta * self.vocab_size as f64;
230        let counts = &self.state.topic_word_counts[topic];
231        let total: f64 = counts.iter().sum::<usize>() as f64 + eta_sum;
232        counts.iter().map(|&c| (c as f64 + eta) / total).collect()
233    }
234
235    // ── document_distribution ─────────────────────────────────────────────────
236
237    /// Return the normalised document-topic distribution for document `d`.
238    ///
239    /// The result is a `Vec<f64>` of length `max_topics` that sums to 1.
240    ///
241    /// # Panics
242    /// Panics when `doc >= n_docs`.
243    pub fn document_distribution(&self, doc: usize) -> Vec<f64> {
244        let alpha = self.config.alpha;
245        let t = self.config.max_topics;
246        let counts = &self.state.doc_topic_counts[doc];
247        let total: f64 = counts.iter().sum::<usize>() as f64 + alpha;
248        counts
249            .iter()
250            .map(|&c| (c as f64 + alpha / t as f64) / total)
251            .collect()
252    }
253
254    // ── active_topics ─────────────────────────────────────────────────────────
255
256    /// Number of topics that have at least one word token assigned.
257    pub fn active_topics(&self) -> usize {
258        self.state.n_topics
259    }
260
261    // ── perplexity ────────────────────────────────────────────────────────────
262
263    /// Per-token perplexity on the training corpus.
264    ///
265    /// Computed as `exp(-avg_log_likelihood)`.  Returns `1.0` when the corpus
266    /// contains no tokens.
267    pub fn perplexity(&self) -> f64 {
268        let t = self.config.max_topics;
269        let eta = self.config.eta;
270        let eta_sum = eta * self.vocab_size as f64;
271        let alpha = self.config.alpha;
272
273        let mut total_ll = 0.0f64;
274        let mut total_tokens = 0usize;
275
276        for (d, doc) in self.corpus.iter().enumerate() {
277            let doc_total: f64 =
278                self.state.doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;
279
280            for &w in doc {
281                if w >= self.vocab_size {
282                    continue;
283                }
284                let p_w: f64 = (0..t)
285                    .map(|k| {
286                        let theta_dk = (self.state.doc_topic_counts[d][k] as f64
287                            + alpha / t as f64)
288                            / doc_total;
289                        let topic_total: f64 =
290                            self.state.topic_word_counts[k].iter().sum::<usize>() as f64 + eta_sum;
291                        let phi_kw =
292                            (self.state.topic_word_counts[k][w] as f64 + eta) / topic_total;
293                        theta_dk * phi_kw
294                    })
295                    .sum();
296
297                if p_w > 0.0 {
298                    total_ll += p_w.ln();
299                }
300                total_tokens += 1;
301            }
302        }
303
304        if total_tokens == 0 {
305            return 1.0;
306        }
307
308        let avg_ll = total_ll / total_tokens as f64;
309        (-avg_ll).exp()
310    }
311
312    // ── top_words ─────────────────────────────────────────────────────────────
313
314    /// Return the top `k` word indices for topic `topic`, sorted by
315    /// descending probability.
316    ///
317    /// If `k >= vocab_size` all word indices are returned.
318    pub fn top_words(&self, topic: usize, k: usize) -> Vec<usize> {
319        let phi = self.topic_distribution(topic);
320        let mut indices: Vec<usize> = (0..phi.len()).collect();
321        indices.sort_by(|&a, &b| {
322            phi[b]
323                .partial_cmp(&phi[a])
324                .unwrap_or(std::cmp::Ordering::Equal)
325        });
326        indices.truncate(k);
327        indices
328    }
329
330    // ── Accessors ─────────────────────────────────────────────────────────────
331
332    /// Borrow the current Gibbs state.
333    pub fn state(&self) -> &HdpState {
334        &self.state
335    }
336
337    /// Whether the model has been fitted.
338    pub fn is_fitted(&self) -> bool {
339        self.fitted
340    }
341}
342
343impl std::fmt::Debug for Hdp {
344    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345        f.debug_struct("Hdp")
346            .field("max_topics", &self.config.max_topics)
347            .field("active_topics", &self.state.n_topics)
348            .field("vocab_size", &self.vocab_size)
349            .field("fitted", &self.fitted)
350            .finish()
351    }
352}
353
354// ── hdp_gibbs_sample ──────────────────────────────────────────────────────────
355
356/// Remove token `(doc, pos, word)` from its current topic assignment, then
357/// sample a new topic from the CRF conditional.
358///
359/// The conditional probability for topic `k` is:
360/// ```text
361/// P(z = k | rest) ∝  (n_{dk} + α/T)  ×  (n_{kw} + η)
362///                    ——————————————————   ————————————
363///                    (n_d   + α)          (n_k + η·V)
364/// ```
365/// where:
366/// - `n_{dk}` = token count for document `d` and topic `k`
367/// - `n_{kw}` = global count of word `w` under topic `k`
368/// - `n_d` = total tokens in document `d`
369/// - `n_k` = total tokens under topic `k`
370fn hdp_gibbs_sample(
371    state: &mut HdpState,
372    doc: usize,
373    pos: usize,
374    word: usize,
375    alpha: f64,
376    _gamma: f64,
377    vocab_size: usize,
378    rng: &mut StdRng,
379) {
380    let t = state.topic_word_counts.len();
381    let eta = 0.1_f64;
382    let eta_sum = eta * vocab_size as f64;
383
384    // Remove current assignment
385    let k_old = state.word_assignments[doc][pos];
386    state.topic_word_counts[k_old][word] = state.topic_word_counts[k_old][word].saturating_sub(1);
387    state.doc_topic_counts[doc][k_old] = state.doc_topic_counts[doc][k_old].saturating_sub(1);
388
389    // Compute unnormalised probabilities for each topic
390    let mut probs = vec![0.0f64; t];
391    for k in 0..t {
392        let doc_factor = state.doc_topic_counts[doc][k] as f64 + alpha / t as f64;
393        let kw = state.topic_word_counts[k][word] as f64 + eta;
394        let k_total: f64 = state.topic_word_counts[k].iter().sum::<usize>() as f64 + eta_sum;
395        probs[k] = doc_factor * (kw / k_total);
396    }
397
398    // Sample new topic
399    let k_new = sample_categorical(&probs, rng);
400
401    // Update counts
402    state.word_assignments[doc][pos] = k_new;
403    state.topic_word_counts[k_new][word] += 1;
404    state.doc_topic_counts[doc][k_new] += 1;
405}
406
407/// Sample a categorical index from an unnormalised probability vector.
408fn sample_categorical(probs: &[f64], rng: &mut StdRng) -> usize {
409    let total: f64 = probs.iter().sum();
410    if total <= 0.0 {
411        return rng.random_range(0..probs.len());
412    }
413    let u: f64 = rng.random_range(0.0..total);
414    let mut cumulative = 0.0f64;
415    for (i, &p) in probs.iter().enumerate() {
416        cumulative += p;
417        if u < cumulative {
418            return i;
419        }
420    }
421    probs.len() - 1
422}
423
424// ── HdpTopicConfig ────────────────────────────────────────────────────────────
425
/// Configuration for [`HdpTopicModel`].
///
/// Task-API counterpart of [`HdpConfig`], which configures [`Hdp`].  Key
/// differences:
/// - `t_max` plays the role of `HdpConfig::max_topics`.
/// - `burn_in` is an additional field.
/// - The defaults for `t_max` (50) and `n_iter` (150) differ from
///   `HdpConfig`'s.
#[derive(Debug, Clone)]
pub struct HdpTopicConfig {
    /// Per-document concentration α: each document receives a pseudo-count
    /// of `α / t_max` per topic.  Default: 1.0.
    pub alpha: f64,
    /// Global DP concentration γ.  Forwarded to [`HdpConfig::gamma`]; the
    /// current Gibbs sampler does not read it.  Default: 1.0.
    pub gamma: f64,
    /// Symmetric Dirichlet word prior η.  Default: 0.1.
    pub eta: f64,
    /// Truncation level T: the maximum number of topics.  Default: 50.
    pub t_max: usize,
    /// Total Gibbs iterations (including burn-in).  Default: 150.
    pub n_iter: usize,
    /// Intended number of burn-in iterations.  Default: 50.
    ///
    /// Note: [`HdpTopicModel::fit`] does not currently read this field.
    /// Active topics are always counted from the sampler state after all
    /// `n_iter` iterations.
    pub burn_in: usize,
    /// RNG seed for reproducibility.  Default: 42.
    pub seed: u64,
}

impl Default for HdpTopicConfig {
    fn default() -> Self {
        HdpTopicConfig {
            alpha: 1.0,
            gamma: 1.0,
            eta: 0.1,
            t_max: 50,
            n_iter: 150,
            burn_in: 50,
            seed: 42,
        }
    }
}
463
464// ── HdpTopicModel ─────────────────────────────────────────────────────────────
465
/// Task-API Hierarchical Dirichlet Process topic model.
///
/// Provides the interface `HdpTopicModel::fit(corpus, vocab_size, config)`,
/// `.transform(doc)`, `.topics()`, and `.num_topics_inferred()`.
///
/// Internally delegates to [`Hdp`] for the Gibbs sampling loop, then
/// post-processes to expose `phi` (topic × word) and `theta` (document × topic)
/// arrays.
///
/// # Example
///
/// ```rust
/// use scirs2_text::topic::hdp::{HdpTopicConfig, HdpTopicModel};
///
/// let corpus = vec![
///     vec![0usize, 1, 2],
///     vec![3usize, 4, 5],
/// ];
/// let cfg = HdpTopicConfig { n_iter: 10, t_max: 5, burn_in: 2, seed: 0, ..Default::default() };
/// let model = HdpTopicModel::fit(&corpus, 6, cfg).expect("fit must succeed");
/// assert!(model.num_topics_inferred() >= 1);
/// ```
pub struct HdpTopicModel {
    /// φ\[k\]\[w\] = smoothed probability of word w in topic k.
    /// Shape: `[t_max × vocab_size]`; every topic is included, active or not.
    pub phi: Vec<Vec<f64>>,
    /// θ\[d\]\[k\] = smoothed topic proportion for training document d.
    /// Shape: `[n_docs × t_max]`.
    pub theta: Vec<Vec<f64>>,
    /// Number of non-empty topics in the final Gibbs state (reported as at
    /// least 1).
    k_inferred: usize,
    /// Vocabulary size used during fit.
    vocab_size: usize,
    /// t_max used during fit (for transform).
    t_max: usize,
    /// eta used during fit (for transform).
    eta: f64,
    /// alpha used during fit (for transform).
    alpha: f64,
    /// Raw topic-word count matrix kept for transform (t_max × vocab_size).
    topic_word_counts: Vec<Vec<usize>>,
    /// Raw per-topic token totals (length t_max).
    topic_counts: Vec<usize>,
}
508
509impl HdpTopicModel {
510    /// Fit the HDP topic model to `corpus`.
511    ///
512    /// # Parameters
513    /// - `corpus`: slice of documents, each a `Vec<usize>` of word indices
514    ///   (all must be < `vocab_size`).
515    /// - `vocab_size`: vocabulary size.
516    /// - `config`: hyperparameters and iteration counts.
517    ///
518    /// # Errors
519    /// Returns [`TopicError::EmptyCorpus`] when `corpus` is empty, and
520    /// [`TopicError::WordOutOfVocab`] when any word index ≥ `vocab_size`.
521    pub fn fit(
522        corpus: &[Vec<usize>],
523        vocab_size: usize,
524        config: HdpTopicConfig,
525    ) -> Result<Self, TopicError> {
526        if corpus.is_empty() {
527            return Err(TopicError::EmptyCorpus);
528        }
529        for doc in corpus {
530            for &w in doc {
531                if w >= vocab_size {
532                    return Err(TopicError::WordOutOfVocab(w, vocab_size));
533                }
534            }
535        }
536
537        let t = config.t_max;
538        let n_docs = corpus.len();
539
540        // Delegate Gibbs sampling to the existing Hdp struct via its HdpConfig
541        let hdp_cfg = HdpConfig {
542            alpha: config.alpha,
543            gamma: config.gamma,
544            eta: config.eta,
545            n_iter: config.n_iter,
546            max_topics: t,
547            seed: config.seed,
548        };
549
550        let mut hdp = Hdp::new(hdp_cfg, n_docs, vocab_size);
551        hdp.fit(corpus)?;
552
553        // Extract counts from HdpState
554        let state = hdp.state();
555        let topic_word_counts: Vec<Vec<usize>> = state.topic_word_counts.clone();
556        let topic_counts: Vec<usize> = topic_word_counts
557            .iter()
558            .map(|row| row.iter().sum())
559            .collect();
560
561        // Count active topics — post burn-in approximated by checking n_k > 0
562        let k_inferred = topic_counts.iter().filter(|&&c| c > 0).count().max(1);
563
564        let eta = config.eta;
565        let eta_sum = eta * vocab_size as f64;
566        let alpha = config.alpha;
567
568        // Compute phi: normalised topic-word distribution for ALL t topics
569        // (only active ones will be indexed by k_inferred)
570        let phi: Vec<Vec<f64>> = (0..t)
571            .map(|k| {
572                let total = topic_counts[k] as f64 + eta_sum;
573                (0..vocab_size)
574                    .map(|w| (topic_word_counts[k][w] as f64 + eta) / total)
575                    .collect()
576            })
577            .collect();
578
579        // Compute theta for training documents
580        let doc_topic_counts = &state.doc_topic_counts;
581        let theta: Vec<Vec<f64>> = (0..n_docs)
582            .map(|d| {
583                let doc_total: f64 = doc_topic_counts[d].iter().sum::<usize>() as f64 + alpha;
584                (0..t)
585                    .map(|k| (doc_topic_counts[d][k] as f64 + alpha / t as f64) / doc_total)
586                    .collect()
587            })
588            .collect();
589
590        Ok(HdpTopicModel {
591            phi,
592            theta,
593            k_inferred,
594            vocab_size,
595            t_max: t,
596            eta,
597            alpha,
598            topic_word_counts,
599            topic_counts,
600        })
601    }
602
603    /// Infer the topic distribution for an unseen document.
604    ///
605    /// Returns a vector of length `t_max` that sums to 1.0, with each entry
606    /// representing the proportion of the document's content assigned to that
607    /// topic.
608    ///
609    /// Word indices ≥ `vocab_size` are silently skipped.
610    pub fn transform(&self, doc: &[usize]) -> Vec<f64> {
611        let t = self.t_max;
612        let eta = self.eta;
613        let eta_sum = eta * self.vocab_size as f64;
614
615        // Initialise with symmetric prior
616        let mut theta_doc = vec![self.alpha / t as f64; t];
617
618        for &w in doc {
619            if w >= self.vocab_size {
620                continue;
621            }
622            // Compute normalised word-topic weights
623            let mut word_probs: Vec<f64> = (0..t)
624                .map(|k| {
625                    theta_doc[k] * (self.topic_word_counts[k][w] as f64 + eta)
626                        / (self.topic_counts[k] as f64 + eta_sum)
627                })
628                .collect();
629
630            let sum: f64 = word_probs.iter().sum();
631            if sum > 0.0 {
632                word_probs.iter_mut().for_each(|p| *p /= sum);
633                for k in 0..t {
634                    theta_doc[k] += word_probs[k];
635                }
636            }
637        }
638
639        // Normalise
640        let total: f64 = theta_doc.iter().sum();
641        if total > 0.0 {
642            theta_doc.iter_mut().for_each(|p| *p /= total);
643        }
644
645        theta_doc
646    }
647
648    /// Return references to all (active + inactive) topic-word distributions.
649    ///
650    /// The outer slice has length `t_max`; the inner slices each have length
651    /// `vocab_size`.  Inactive topics have a uniform distribution over the
652    /// prior η.
653    pub fn topics(&self) -> &[Vec<f64>] {
654        &self.phi
655    }
656
657    /// Number of topics with at least one word token assigned after Gibbs
658    /// sampling (approximates the model's belief about how many topics the
659    /// corpus requires).
660    pub fn num_topics_inferred(&self) -> usize {
661        self.k_inferred
662    }
663}
664
665impl std::fmt::Debug for HdpTopicModel {
666    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
667        f.debug_struct("HdpTopicModel")
668            .field("t_max", &self.t_max)
669            .field("k_inferred", &self.k_inferred)
670            .field("vocab_size", &self.vocab_size)
671            .finish()
672    }
673}
674
675// ── Tests ─────────────────────────────────────────────────────────────────────
676
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic corpus of 3 well-separated topics over a 15-word vocabulary.
    ///
    /// Produces `n_per_topic` documents per topic (so `3 * n_per_topic`
    /// documents in total), each with 20 tokens:
    /// - topic 0 draws from words 0–4,
    /// - topic 1 from words 5–9,
    /// - topic 2 from words 10–14.
    fn make_corpus(n_per_topic: usize, seed: u64) -> Vec<Vec<usize>> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut corpus = Vec::new();
        // Topic 0: words 0–4
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(0..5)).collect());
        }
        // Topic 1: words 5–9
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(5..10)).collect());
        }
        // Topic 2: words 10–14
        for _ in 0..n_per_topic {
            corpus.push((0..20).map(|_| rng.random_range(10..15)).collect());
        }
        corpus
    }

    // ── active_topics ────────────────────────────────────────────────────────

    #[test]
    fn active_topics_in_valid_range() {
        let corpus = make_corpus(10, 1);
        let config = HdpConfig {
            n_iter: 20,
            max_topics: 15,
            seed: 42,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let active = model.active_topics();
        assert!(active >= 1, "active topics must be >= 1, got {active}");
        assert!(
            active <= 15,
            "active topics ({active}) must be <= max_topics (15)"
        );
    }

    // ── topic_distribution ───────────────────────────────────────────────────

    #[test]
    fn topic_distribution_sums_to_one() {
        let corpus = make_corpus(8, 2);
        let config = HdpConfig {
            n_iter: 10,
            seed: 7,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let dist = model.topic_distribution(0);
        let sum: f64 = dist.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "topic_distribution must sum to 1.0, got {sum}"
        );
    }

    // ── document_distribution ────────────────────────────────────────────────

    #[test]
    fn document_distribution_sums_to_one() {
        let corpus = make_corpus(8, 3);
        let config = HdpConfig {
            n_iter: 10,
            seed: 11,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let dist = model.document_distribution(0);
        let sum: f64 = dist.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "document_distribution must sum to 1.0, got {sum}"
        );
    }

    // ── perplexity ───────────────────────────────────────────────────────────

    #[test]
    fn perplexity_is_finite_positive() {
        let corpus = make_corpus(8, 4);
        let config = HdpConfig {
            n_iter: 15,
            seed: 99,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let pp = model.perplexity();
        assert!(pp.is_finite(), "perplexity must be finite, got {pp}");
        assert!(pp > 0.0, "perplexity must be positive, got {pp}");
    }

    // ── top_words ────────────────────────────────────────────────────────────

    #[test]
    fn top_words_returns_k_distinct_indices() {
        let corpus = make_corpus(10, 5);
        let config = HdpConfig {
            n_iter: 15,
            seed: 55,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let top5 = model.top_words(0, 5);
        // All indices distinct
        let mut sorted = top5.clone();
        sorted.sort_unstable();
        sorted.dedup();
        assert_eq!(
            sorted.len(),
            top5.len(),
            "top_words must contain distinct indices"
        );
        // All within vocab range
        for &w in &top5 {
            assert!(w < 15, "word index {w} must be < vocab_size 15");
        }
    }

    // ── error cases ──────────────────────────────────────────────────────────

    #[test]
    fn fit_empty_corpus_returns_error() {
        let mut model = Hdp::new(HdpConfig::default(), 0, 10);
        let result = model.fit(&[]);
        assert!(
            result.is_err(),
            "fit on empty corpus must return TopicError"
        );
    }

    #[test]
    fn fit_out_of_vocab_returns_error() {
        let corpus = vec![vec![0usize, 1, 99]]; // 99 >= vocab_size=5
        let mut model = Hdp::new(HdpConfig::default(), 1, 5);
        let result = model.fit(&corpus);
        assert!(
            result.is_err(),
            "fit with OOV word must return TopicError::WordOutOfVocab"
        );
    }

    #[test]
    fn top_words_all_nontrivial() {
        let corpus = make_corpus(6, 6);
        let config = HdpConfig {
            n_iter: 10,
            seed: 77,
            max_topics: 10,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");
        // For all topics, top 3 words must be valid indices
        for k in 0..10 {
            for &w in &model.top_words(k, 3) {
                assert!(w < 15, "top word index {w} must be in vocab");
            }
        }
    }

    // ── determinism & count invariants ───────────────────────────────────────

    #[test]
    fn fit_is_reproducible_for_fixed_seed() {
        let corpus = make_corpus(4, 8);
        let config = HdpConfig {
            n_iter: 5,
            max_topics: 6,
            seed: 3,
            ..Default::default()
        };
        let mut a = Hdp::new(config.clone(), corpus.len(), 15);
        let mut b = Hdp::new(config, corpus.len(), 15);
        a.fit(&corpus).expect("fit must succeed");
        b.fit(&corpus).expect("fit must succeed");
        assert_eq!(a.state().word_assignments, b.state().word_assignments);
        assert_eq!(a.active_topics(), b.active_topics());
        assert!(a.is_fitted() && b.is_fitted());
    }

    #[test]
    fn counts_are_consistent_after_fit() {
        let corpus = make_corpus(4, 9);
        let config = HdpConfig {
            n_iter: 5,
            max_topics: 6,
            seed: 13,
            ..Default::default()
        };
        let mut model = Hdp::new(config, corpus.len(), 15);
        model.fit(&corpus).expect("fit must succeed");

        let state = model.state();
        let n_tokens: usize = corpus.iter().map(Vec::len).sum();
        let topic_total: usize = state.topic_word_counts.iter().flatten().sum();
        assert_eq!(topic_total, n_tokens, "topic-word counts must cover every token");
        for (d, doc) in corpus.iter().enumerate() {
            assert_eq!(
                state.doc_topic_counts[d].iter().sum::<usize>(),
                doc.len(),
                "doc-topic counts of document {d} must match its length"
            );
        }
    }

    // ── HdpTopicModel ────────────────────────────────────────────────────────

    #[test]
    fn topic_model_shapes_and_transform() {
        let corpus = make_corpus(4, 10);
        let cfg = HdpTopicConfig {
            n_iter: 5,
            t_max: 6,
            burn_in: 1,
            seed: 1,
            ..Default::default()
        };
        let model = HdpTopicModel::fit(&corpus, 15, cfg).expect("fit must succeed");

        assert_eq!(model.topics().len(), 6);
        assert!(model.topics().iter().all(|row| row.len() == 15));
        assert_eq!(model.theta.len(), corpus.len());
        let k = model.num_topics_inferred();
        assert!((1..=6).contains(&k), "inferred topics {k} out of range");

        // Out-of-vocabulary ids are skipped by transform.
        let dist = model.transform(&[0, 1, 2, 99]);
        assert_eq!(dist.len(), 6);
        let sum: f64 = dist.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-9,
            "transform must sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn topic_model_fit_rejects_bad_input() {
        let cfg = HdpTopicConfig::default();
        assert!(matches!(
            HdpTopicModel::fit(&[], 5, cfg.clone()),
            Err(TopicError::EmptyCorpus)
        ));
        assert!(matches!(
            HdpTopicModel::fit(&[vec![0, 7]], 5, cfg),
            Err(TopicError::WordOutOfVocab(7, 5))
        ));
    }
}