scirs2_text/dtm/mod.rs
1//! # Dynamic Topic Model (DTM)
2//!
3//! Implements the Dynamic Topic Model of Blei & Lafferty (2006), which extends
4//! LDA by modelling topic evolution over discrete time slices via a Gaussian
5//! state-space model.
6//!
7//! ## Model
8//!
9//! ```text
10//! β_{t,k} | β_{t-1,k} ~ N(β_{t-1,k}, σ² I) (topic word evolution)
11//! θ_d ~ Dir(α) (document-topic)
12//! z_{dn} ~ Categorical(θ_d) (topic assignment)
//! w_{dn} ~ Categorical(softmax(β_{t,z_{dn}}))         (word generation)
14//! ```
15//!
16//! Inference is performed via variational EM with a Kalman smoother on the
17//! topic-word parameters.
18//!
19//! ## Example
20//!
21//! ```rust
22//! use scirs2_text::dtm::{DynamicTopicModel, DtmConfig};
23//!
24//! let config = DtmConfig {
25//! n_topics: 2,
26//! n_time_slices: 3,
27//! max_iter: 5,
28//! sigma_sq: 0.1,
29//! alpha: 0.1,
30//! };
31//! let model = DynamicTopicModel::new(config);
32//!
//! // 3 time slices, each with 4 documents; every document is a length-6 vector
//! // with one entry per vocabulary word (matching `vocab_size = 6` below)
34//! let docs_by_time: Vec<Vec<Vec<f64>>> = (0..3)
35//! .map(|t| {
36//! (0..4)
37//! .map(|d| (0..6).map(|w| ((t + d + w) % 3) as f64).collect())
38//! .collect()
39//! })
40//! .collect();
41//!
42//! let result = model.fit(&docs_by_time, 6).expect("DTM fit failed");
43//! assert_eq!(result.topic_word_trajectories.len(), 2); // K topics
44//! ```
45
46pub mod inference;
47pub mod model;
48
49use crate::error::Result;
50
51// ────────────────────────────────────────────────────────────────────────────
52// Public re-exports
53// ────────────────────────────────────────────────────────────────────────────
54
55pub use inference::{kalman_backward, kalman_forward};
56pub use model::{top_words_at_time, topic_evolution};
57
58// ────────────────────────────────────────────────────────────────────────────
59// Configuration
60// ────────────────────────────────────────────────────────────────────────────
61
/// Hyper-parameters controlling a [`DynamicTopicModel`] fit.
///
/// Start from [`DtmConfig::default`] and override individual fields with
/// struct-update syntax, e.g. `DtmConfig { n_topics: 5, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct DtmConfig {
    /// Number of latent topics, `K`.
    pub n_topics: usize,
    /// Number of discrete time slices, `T` (may be 0; inferred from data if so).
    pub n_time_slices: usize,
    /// Upper bound on the number of variational EM iterations.
    pub max_iter: usize,
    /// Variance `σ²` of the Gaussian random walk linking consecutive time slices.
    pub sigma_sq: f64,
    /// Concentration `α` of the (symmetric, scalar) Dirichlet prior on
    /// document-topic proportions.
    pub alpha: f64,
}

impl Default for DtmConfig {
    /// Stock settings: 10 topics, `n_time_slices = 0`, 50 EM iterations,
    /// `σ² = 0.5` and `α = 0.01`.
    fn default() -> Self {
        let (n_topics, n_time_slices, max_iter) = (10, 0, 50);
        let (sigma_sq, alpha) = (0.5, 0.01);
        Self {
            n_topics,
            n_time_slices,
            max_iter,
            sigma_sq,
            alpha,
        }
    }
}
88
89// ────────────────────────────────────────────────────────────────────────────
90// Result type
91// ────────────────────────────────────────────────────────────────────────────
92
/// Output of a fitted Dynamic Topic Model.
#[derive(Debug, Clone)]
pub struct DtmResult {
    /// Topic-word trajectories with shape `K × T × V`
    /// (topics × time slices × vocabulary words).
    ///
    /// `topic_word_trajectories[k][t][w]` is the probability of word `w` under
    /// topic `k` at time slice `t`. Each inner slice
    /// `topic_word_trajectories[k][t]` sums to 1.
    pub topic_word_trajectories: Vec<Vec<Vec<f64>>>,
    /// Flattened document-topic distribution: one row per document, with the
    /// documents of all time slices concatenated. Each row sums to 1.
    pub doc_topic_matrix: Vec<Vec<f64>>,
}
105
106// ────────────────────────────────────────────────────────────────────────────
107// Model struct
108// ────────────────────────────────────────────────────────────────────────────
109
110/// Dynamic Topic Model estimator.
111///
112/// Fit via [`DynamicTopicModel::fit`]; the result is returned as a [`DtmResult`].
113pub struct DynamicTopicModel {
114 /// Model configuration.
115 pub config: DtmConfig,
116 /// Fitted result (populated after `fit_and_store`).
117 fitted: Option<DtmResult>,
118}
119
120impl DynamicTopicModel {
121 /// Construct a new (unfitted) DTM with the given configuration.
122 pub fn new(config: DtmConfig) -> Self {
123 Self {
124 config,
125 fitted: None,
126 }
127 }
128
129 /// Return a reference to the fitted result, if available.
130 pub fn fitted_result(&self) -> Option<&DtmResult> {
131 self.fitted.as_ref()
132 }
133
134 /// Fit the model and store the result internally, also returning it.
135 pub fn fit_and_store(
136 &mut self,
137 docs_by_time: &[Vec<Vec<f64>>],
138 vocab_size: usize,
139 ) -> Result<&DtmResult> {
140 let result = self.fit(docs_by_time, vocab_size)?;
141 self.fitted = Some(result);
142 Ok(self.fitted.as_ref().expect("just set"))
143 }
144
145 /// Return the top-`n` words for each topic at time `t` using the fitted model.
146 pub fn top_words_at(&self, t: usize, vocab: &[String], n: usize) -> Option<Vec<Vec<String>>> {
147 self.fitted
148 .as_ref()
149 .map(|r| top_words_at_time(&r.topic_word_trajectories, t, vocab, n))
150 }
151
152 /// Return the evolution of word `word_id` in topic `topic_id` over time.
153 pub fn word_evolution(&self, topic_id: usize, word_id: usize) -> Option<Vec<f64>> {
154 self.fitted
155 .as_ref()
156 .map(|r| topic_evolution(&r.topic_word_trajectories, topic_id, word_id))
157 }
158}
159
160impl Default for DynamicTopicModel {
161 fn default() -> Self {
162 Self::new(DtmConfig::default())
163 }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fast configuration shared by the fitting tests.
    fn small_config(n_time_slices: usize) -> DtmConfig {
        DtmConfig {
            n_topics: 2,
            n_time_slices,
            max_iter: 3,
            sigma_sq: 0.1,
            alpha: 0.1,
        }
    }

    /// Deterministic synthetic corpus: `n_slices` time slices of `n_docs`
    /// documents, each a length-`vocab_size` vector with entries in {0, 1, 2}.
    fn synthetic_docs(n_slices: usize, n_docs: usize, vocab_size: usize) -> Vec<Vec<Vec<f64>>> {
        (0..n_slices)
            .map(|t| {
                (0..n_docs)
                    .map(|d| (0..vocab_size).map(|w| ((t + d + w) % 3) as f64).collect())
                    .collect()
            })
            .collect()
    }

    #[test]
    fn dtm_default_config() {
        let cfg = DtmConfig::default();
        assert_eq!(cfg.n_topics, 10);
        assert_eq!(cfg.n_time_slices, 0);
        assert_eq!(cfg.max_iter, 50);
        assert!((cfg.sigma_sq - 0.5).abs() < 1e-12);
        assert!((cfg.alpha - 0.01).abs() < 1e-12);
    }

    #[test]
    fn dtm_default_model() {
        let m = DynamicTopicModel::default();
        assert_eq!(m.config.n_topics, 10);
        assert!(m.fitted_result().is_none());
    }

    #[test]
    fn dtm_accessors_return_none_before_fit() {
        let m = DynamicTopicModel::default();
        let vocab: Vec<String> = vec!["w0".to_string()];
        assert!(m.top_words_at(0, &vocab, 1).is_none());
        assert!(m.word_evolution(0, 0).is_none());
    }

    #[test]
    fn dtm_fit_and_store() {
        let mut model = DynamicTopicModel::new(small_config(2));
        let docs_by_time = synthetic_docs(2, 3, 4);
        let n_trajectories = model
            .fit_and_store(&docs_by_time, 4)
            .expect("fit failed")
            .topic_word_trajectories
            .len();
        assert_eq!(n_trajectories, 2); // one trajectory per topic
        assert!(model.fitted_result().is_some());
    }

    #[test]
    fn dtm_top_words_at_after_fit() {
        let mut model = DynamicTopicModel::new(small_config(2));
        let docs_by_time = synthetic_docs(2, 3, 5);
        model.fit_and_store(&docs_by_time, 5).expect("fit failed");
        let vocab: Vec<String> = (0..5).map(|i| format!("w{i}")).collect();
        let tw = model.top_words_at(0, &vocab, 3).expect("no fitted result");
        assert_eq!(tw.len(), 2); // K topics
        assert_eq!(tw[0].len(), 3); // n words
    }

    #[test]
    fn dtm_word_evolution_length_equals_t() {
        let mut model = DynamicTopicModel::new(small_config(4));
        let docs_by_time = synthetic_docs(4, 2, 5);
        model.fit_and_store(&docs_by_time, 5).expect("fit failed");
        let ev = model.word_evolution(0, 2).expect("no fitted result");
        assert_eq!(ev.len(), 4);
    }
}