Skip to main content

scirs2_text/ctm/
mod.rs

1//! # Correlated Topic Model (CTM)
2//!
3//! Implements the Correlated Topic Model of Blei & Lafferty (2006), which uses
4//! a logistic-normal prior over topic proportions to capture inter-topic
5//! correlations that a plain Dirichlet prior cannot represent.
6//!
7//! ## Overview
8//!
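//! Unlike LDA, CTM draws each document's topic proportions from a
//! logistic-normal distribution and then generates every word through a
//! per-word topic assignment:
//!
//! ```text
//!   η_d  ~ N(μ, Σ)
//!   θ_d  =  softmax(η_d)
//!   z_dn ~ Multinomial(θ_d)
//!   w_dn ~ Multinomial(β_{z_dn})
//! ```
//!
//! Because the logistic-normal prior is not conjugate to the multinomial,
//! exact posterior inference is intractable; inference is therefore performed
//! via mean-field variational EM.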
18//!
19//! ## Example
20//!
21//! ```rust
22//! use scirs2_text::ctm::{CorrelatedTopicModel, CtmConfig};
23//!
24//! let config = CtmConfig {
25//!     n_topics: 3,
26//!     max_iter: 20,
27//!     tol: 1e-4,
28//!     vocab_size: 10,
29//! };
30//! let model = CorrelatedTopicModel::new(config);
31//!
32//! // Documents as word-count vectors (length = vocab_size)
33//! let docs: Vec<Vec<f64>> = (0..5)
34//!     .map(|i| (0..10).map(|w| ((i * 3 + w) % 4) as f64).collect())
35//!     .collect();
36//!
37//! let result = model.fit(&docs, 10).expect("CTM fit failed");
38//! assert_eq!(result.topic_word_matrix.len(), 3);
39//! ```
40
41pub mod inference;
42pub mod model;
43
44use crate::error::Result;
45
46// ────────────────────────────────────────────────────────────────────────────
47// Public re-exports
48// ────────────────────────────────────────────────────────────────────────────
49
50pub use model::{log_likelihood, softmax, top_words, topic_correlation_matrix};
51
52// ────────────────────────────────────────────────────────────────────────────
53// Configuration
54// ────────────────────────────────────────────────────────────────────────────
55
/// Hyper-parameters that control how a Correlated Topic Model is fitted.
#[derive(Debug, Clone)]
pub struct CtmConfig {
    /// How many latent topics `K` to learn.
    pub n_topics: usize,
    /// Upper bound on the number of variational-EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance applied to the ELBO.
    pub tol: f64,
    /// Vocabulary size `V`; a value of `0` means it is inferred from the data.
    pub vocab_size: usize,
}

impl Default for CtmConfig {
    /// Ten topics, at most 100 iterations, tolerance `1e-4`, and a vocabulary
    /// size of `0` (inferred from the data).
    fn default() -> Self {
        let n_topics = 10;
        let max_iter = 100;
        let tol = 1e-4;
        let vocab_size = 0;
        CtmConfig {
            n_topics,
            max_iter,
            tol,
            vocab_size,
        }
    }
}
79
80// ────────────────────────────────────────────────────────────────────────────
81// Result type
82// ────────────────────────────────────────────────────────────────────────────
83
/// Everything produced by fitting a Correlated Topic Model.
#[derive(Debug, Clone)]
pub struct CtmResult {
    /// `K × V` matrix of per-topic word probabilities; every row sums to 1.
    pub topic_word_matrix: Vec<Vec<f64>>,
    /// `D × K` matrix of per-document topic probabilities; every row sums to 1.
    pub doc_topic_matrix: Vec<Vec<f64>>,
    /// Estimated mean μ of the logistic-normal prior (length `K`).
    pub mu: Vec<f64>,
    /// Estimated covariance Σ of the logistic-normal prior (`K × K`).
    pub sigma: Vec<Vec<f64>>,
    /// Approximate corpus log-likelihood under the fitted parameters.
    pub log_likelihood: f64,
}
98
99// ────────────────────────────────────────────────────────────────────────────
100// Model struct
101// ────────────────────────────────────────────────────────────────────────────
102
103/// Correlated Topic Model estimator.
104///
105/// Fit with [`CorrelatedTopicModel::fit`]; the fitted result is returned as
106/// a [`CtmResult`] value (the model itself is stateless after construction).
107pub struct CorrelatedTopicModel {
108    /// Model configuration.
109    pub config: CtmConfig,
110    /// Fitted result (populated after `fit`).
111    fitted: Option<CtmResult>,
112}
113
114impl CorrelatedTopicModel {
115    /// Construct a new (unfitted) CTM with the given configuration.
116    pub fn new(config: CtmConfig) -> Self {
117        Self {
118            config,
119            fitted: None,
120        }
121    }
122
123    /// Return a reference to the fitted result, if available.
124    pub fn fitted_result(&self) -> Option<&CtmResult> {
125        self.fitted.as_ref()
126    }
127
128    /// Fit the model and store the result internally, also returning it.
129    pub fn fit_and_store(
130        &mut self,
131        doc_counts_list: &[Vec<f64>],
132        vocab_size: usize,
133    ) -> Result<&CtmResult> {
134        let result = self.fit(doc_counts_list, vocab_size)?;
135        self.fitted = Some(result);
136        Ok(self.fitted.as_ref().expect("just set"))
137    }
138
139    /// Return the top-`n` words for each topic given a vocabulary.
140    ///
141    /// Requires the model to have been fitted (via `fit_and_store`).
142    pub fn top_words_from_fitted(&self, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
143        self.fitted
144            .as_ref()
145            .map(|r| top_words(&r.topic_word_matrix, vocab, n))
146    }
147
148    /// Compute the inter-topic correlation matrix from the fitted Σ.
149    pub fn correlation_matrix_from_fitted(&self) -> Option<Vec<Vec<f64>>> {
150        self.fitted
151            .as_ref()
152            .map(|r| topic_correlation_matrix(&r.sigma))
153    }
154}
155
156impl Default for CorrelatedTopicModel {
157    fn default() -> Self {
158        Self::new(CtmConfig::default())
159    }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented values.
    #[test]
    fn ctm_default_config() {
        let CtmConfig {
            n_topics,
            max_iter,
            tol,
            ..
        } = CtmConfig::default();
        assert_eq!(n_topics, 10);
        assert_eq!(max_iter, 100);
        assert!((tol - 1e-4).abs() < 1e-12);
    }

    /// A default model uses the default config and has nothing cached.
    #[test]
    fn ctm_model_default() {
        let model = CorrelatedTopicModel::default();
        assert_eq!(model.config.n_topics, 10);
        assert!(model.fitted_result().is_none());
    }

    /// `fit_and_store` leaves a cached result behind.
    #[test]
    fn ctm_fit_and_store() {
        let config = CtmConfig {
            n_topics: 2,
            max_iter: 5,
            tol: 1e-3,
            vocab_size: 4,
        };
        let mut model = CorrelatedTopicModel::new(config);
        let docs: Vec<Vec<f64>> = (0..4)
            .map(|doc| (0..4).map(|word| ((doc + word) % 3) as f64).collect())
            .collect();
        model.fit_and_store(&docs, 4).expect("fit_and_store failed");
        assert!(model.fitted_result().is_some());
    }
}