scirs2_text/ctm/mod.rs
1//! # Correlated Topic Model (CTM)
2//!
3//! Implements the Correlated Topic Model of Blei & Lafferty (2006), which uses
4//! a logistic-normal prior over topic proportions to capture inter-topic
5//! correlations that a plain Dirichlet prior cannot represent.
6//!
7//! ## Overview
8//!
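//! Unlike LDA, which draws each document's topic proportions from a Dirichlet,
//! CTM draws them through a logistic-normal distribution. For each document `d`
//! and word position `n`:
//!
//! ```text
//! η_d  ~ N(μ, Σ)
//! θ_d  = softmax(η_d)
//! z_dn ~ Mult(θ_d)            (topic assignment)
//! w_dn ~ Mult(β_{z_dn})       (word drawn from the assigned topic)
//! ```
//!
//! where `β_k` is the word distribution of topic `k`.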
16//!
17//! Inference is performed via mean-field variational EM.
18//!
19//! ## Example
20//!
21//! ```rust
22//! use scirs2_text::ctm::{CorrelatedTopicModel, CtmConfig};
23//!
24//! let config = CtmConfig {
25//! n_topics: 3,
26//! max_iter: 20,
27//! tol: 1e-4,
28//! vocab_size: 10,
29//! };
30//! let model = CorrelatedTopicModel::new(config);
31//!
32//! // Documents as word-count vectors (length = vocab_size)
33//! let docs: Vec<Vec<f64>> = (0..5)
34//! .map(|i| (0..10).map(|w| ((i * 3 + w) % 4) as f64).collect())
35//! .collect();
36//!
37//! let result = model.fit(&docs, 10).expect("CTM fit failed");
38//! assert_eq!(result.topic_word_matrix.len(), 3);
39//! ```
40
41pub mod inference;
42pub mod model;
43
44use crate::error::Result;
45
46// ────────────────────────────────────────────────────────────────────────────
47// Public re-exports
48// ────────────────────────────────────────────────────────────────────────────
49
50pub use model::{log_likelihood, softmax, top_words, topic_correlation_matrix};
51
52// ────────────────────────────────────────────────────────────────────────────
53// Configuration
54// ────────────────────────────────────────────────────────────────────────────
55
/// Hyper-parameters controlling a Correlated Topic Model fit.
///
/// Start from [`CtmConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `CtmConfig { n_topics: 5, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct CtmConfig {
    /// How many latent topics (`K`) to learn.
    pub n_topics: usize,
    /// Upper bound on the number of variational EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance applied to the evidence lower bound (ELBO).
    pub tol: f64,
    /// Size of the vocabulary (`V`); may be 0, in which case it is inferred
    /// from the data.
    pub vocab_size: usize,
}

impl Default for CtmConfig {
    /// Ten topics, at most 100 EM iterations, a tolerance of `1e-4`, and a
    /// vocabulary size of 0.
    fn default() -> Self {
        CtmConfig {
            n_topics: 10,
            max_iter: 100,
            tol: 1.0e-4,
            vocab_size: 0,
        }
    }
}
79
80// ────────────────────────────────────────────────────────────────────────────
81// Result type
82// ────────────────────────────────────────────────────────────────────────────
83
/// Everything produced by fitting a Correlated Topic Model.
///
/// Matrices are nested `Vec`s indexed by row first. `K` is the number of
/// topics, `V` the vocabulary size and `D` the number of documents.
#[derive(Clone, Debug)]
pub struct CtmResult {
    /// `K × V` topic-word probabilities; every row sums to 1.
    pub topic_word_matrix: Vec<Vec<f64>>,
    /// `D × K` document-topic probabilities; every row sums to 1.
    pub doc_topic_matrix: Vec<Vec<f64>>,
    /// Fitted prior mean `μ` (length `K`).
    pub mu: Vec<f64>,
    /// Fitted prior covariance `Σ` (`K × K`).
    pub sigma: Vec<Vec<f64>>,
    /// Approximate corpus log-likelihood under the fitted model.
    pub log_likelihood: f64,
}
98
99// ────────────────────────────────────────────────────────────────────────────
100// Model struct
101// ────────────────────────────────────────────────────────────────────────────
102
103/// Correlated Topic Model estimator.
104///
105/// Fit with [`CorrelatedTopicModel::fit`]; the fitted result is returned as
106/// a [`CtmResult`] value (the model itself is stateless after construction).
107pub struct CorrelatedTopicModel {
108 /// Model configuration.
109 pub config: CtmConfig,
110 /// Fitted result (populated after `fit`).
111 fitted: Option<CtmResult>,
112}
113
114impl CorrelatedTopicModel {
115 /// Construct a new (unfitted) CTM with the given configuration.
116 pub fn new(config: CtmConfig) -> Self {
117 Self {
118 config,
119 fitted: None,
120 }
121 }
122
123 /// Return a reference to the fitted result, if available.
124 pub fn fitted_result(&self) -> Option<&CtmResult> {
125 self.fitted.as_ref()
126 }
127
128 /// Fit the model and store the result internally, also returning it.
129 pub fn fit_and_store(
130 &mut self,
131 doc_counts_list: &[Vec<f64>],
132 vocab_size: usize,
133 ) -> Result<&CtmResult> {
134 let result = self.fit(doc_counts_list, vocab_size)?;
135 self.fitted = Some(result);
136 Ok(self.fitted.as_ref().expect("just set"))
137 }
138
139 /// Return the top-`n` words for each topic given a vocabulary.
140 ///
141 /// Requires the model to have been fitted (via `fit_and_store`).
142 pub fn top_words_from_fitted(&self, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
143 self.fitted
144 .as_ref()
145 .map(|r| top_words(&r.topic_word_matrix, vocab, n))
146 }
147
148 /// Compute the inter-topic correlation matrix from the fitted Σ.
149 pub fn correlation_matrix_from_fitted(&self) -> Option<Vec<Vec<f64>>> {
150 self.fitted
151 .as_ref()
152 .map(|r| topic_correlation_matrix(&r.sigma))
153 }
154}
155
156impl Default for CorrelatedTopicModel {
157 fn default() -> Self {
158 Self::new(CtmConfig::default())
159 }
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn ctm_default_config() {
        let cfg = CtmConfig::default();
        assert_eq!(cfg.n_topics, 10);
        assert_eq!(cfg.max_iter, 100);
        assert!((cfg.tol - 1e-4).abs() < 1e-12);
        assert_eq!(cfg.vocab_size, 0);
    }

    #[test]
    fn ctm_model_default() {
        let m = CorrelatedTopicModel::default();
        assert_eq!(m.config.n_topics, 10);
        assert!(m.fitted_result().is_none());
        // Accessors that depend on a stored result must report its absence.
        assert!(m.correlation_matrix_from_fitted().is_none());
        assert!(m.top_words_from_fitted(&[], 3).is_none());
    }

    #[test]
    fn ctm_fit_and_store() {
        let mut model = CorrelatedTopicModel::new(CtmConfig {
            n_topics: 2,
            max_iter: 5,
            tol: 1e-3,
            vocab_size: 4,
        });
        let docs: Vec<Vec<f64>> = (0..4)
            .map(|i| (0..4).map(|w| ((i + w) % 3) as f64).collect())
            .collect();

        // The returned reference must describe one topic-word row per topic.
        let returned_rows = model
            .fit_and_store(&docs, 4)
            .expect("fit_and_store failed")
            .topic_word_matrix
            .len();
        assert_eq!(returned_rows, 2);

        // ...and the same result must now be retrievable from the model.
        let stored = model.fitted_result().expect("result should be stored");
        assert_eq!(stored.topic_word_matrix.len(), 2);

        let vocab: Vec<String> = ["a", "b", "c", "d"].iter().map(|s| s.to_string()).collect();
        assert!(model.top_words_from_fitted(&vocab, 2).is_some());
        assert!(model.correlation_matrix_from_fitted().is_some());
    }
}