Skip to main content

scirs2_text/dtm/
inference.rs

1//! Variational inference for the Dynamic Topic Model (DTM).
2//!
3//! Implements:
4//! - Variational Kalman filter (forward pass)
5//! - RTS smoother (backward pass)
6//! - Per-document E-step (Dirichlet-LDA-style variational update)
7//! - Global M-step (Kalman smoother on sufficient statistics)
8//! - Main `fit` routine on `DynamicTopicModel`
9
10use crate::dtm::model::normalise_to_simplex;
11use crate::dtm::{DtmConfig, DtmResult, DynamicTopicModel};
12use crate::error::{Result, TextError};
13
14// ────────────────────────────────────────────────────────────────────────────
15// Variational Kalman filter & RTS smoother
16// ────────────────────────────────────────────────────────────────────────────
17
/// Forward (filtering) pass of a scalar Kalman filter for one word dimension.
///
/// The assumed model is a Gaussian random walk observed under Gaussian noise:
/// ```text
///   β_{t,k,w} ~ N(β_{t-1,k,w}, σ²)        (state transition)
///   y_{t,k,w} ~ N(β_{t,k,w}, obs_noise)   (observation)
/// ```
///
/// The prior for the first step is centred on the first observation and has
/// variance `sigma_sq + obs_noise`, so the first innovation is zero.
///
/// # Arguments
/// * `observations`  – observed sufficient statistics, one per time step (length T)
/// * `sigma_sq`      – state transition variance σ²
/// * `obs_noise`     – observation noise variance
///
/// # Returns
/// `(filter_means, filter_vars, pred_means, pred_vars)`, each of length T.
/// All four vectors are empty when `observations` is empty.
pub fn kalman_forward(
    observations: &[f64],
    sigma_sq: f64,
    obs_noise: f64,
) -> (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>) {
    let steps = observations.len();
    let mut filter_means = Vec::with_capacity(steps);
    let mut filter_vars = Vec::with_capacity(steps);
    let mut pred_means = Vec::with_capacity(steps);
    let mut pred_vars = Vec::with_capacity(steps);

    // Filtered (mean, variance) of the previous step; `None` before the first one.
    let mut previous: Option<(f64, f64)> = None;

    for &y in observations {
        // Predict: either the broad prior (first step) or a random-walk step.
        let (predicted_mean, predicted_var) = match previous {
            None => (y, sigma_sq + obs_noise),
            Some((mean, var)) => (mean, var + sigma_sq),
        };

        // Update with the Kalman gain for a scalar observation.
        let gain = predicted_var / (predicted_var + obs_noise);
        let updated_mean = predicted_mean + gain * (y - predicted_mean);
        let updated_var = (1.0 - gain) * predicted_var;

        pred_means.push(predicted_mean);
        pred_vars.push(predicted_var);
        filter_means.push(updated_mean);
        filter_vars.push(updated_var);

        previous = Some((updated_mean, updated_var));
    }

    (filter_means, filter_vars, pred_means, pred_vars)
}
75
/// Backward (smoothing) pass of the Rauch-Tung-Striebel smoother.
///
/// Starting from the last filtered state, each earlier step is corrected with
/// ```text
///   G_s        = v_s / max(pred_v_{s+1}, 1e-15)
///   m_s^smooth = m_s + G_s  * (m_{s+1}^smooth - pred_m_{s+1})
///   v_s^smooth = max(v_s + G_s² * (v_{s+1}^smooth - max(pred_v_{s+1}, 1e-15)), 1e-15)
/// ```
///
/// # Arguments
/// * `filter_means` – Kalman filter means (length T)
/// * `filter_vars`  – Kalman filter variances (length T)
/// * `pred_means`   – Kalman prediction means (length T, from the forward pass)
/// * `pred_vars`    – Kalman prediction variances (length T, from the forward pass)
/// * `sigma_sq`     – state transition variance σ²; not read here, the recursion
///   only uses the prediction variances it is given
///
/// # Returns
/// `(smoother_means, smoother_vars)`, each of length T (empty when T = 0).
///
/// # Panics
/// Indexing panics if `filter_vars`, `pred_means` or `pred_vars` are shorter
/// than the positions the recursion reads for a `filter_means` of length T.
pub fn kalman_backward(
    filter_means: &[f64],
    filter_vars: &[f64],
    pred_means: &[f64],
    pred_vars: &[f64],
    sigma_sq: f64,
) -> (Vec<f64>, Vec<f64>) {
    // Kept in the signature for callers; intentionally unused in the recursion.
    let _ = sigma_sq;

    let last = match filter_means.len().checked_sub(1) {
        Some(index) => index,
        None => return (Vec::new(), Vec::new()),
    };

    let mut smoother_means = vec![0.0_f64; last + 1];
    let mut smoother_vars = vec![0.0_f64; last + 1];

    // The final smoothed state coincides with the final filtered state.
    smoother_means[last] = filter_means[last];
    smoother_vars[last] = filter_vars[last];

    for next in (1..=last).rev() {
        let cur = next - 1;

        // Floor the prediction variance so the gain never divides by zero.
        let floored_pred_var = pred_vars[next].max(1e-15);
        let gain = filter_vars[cur] / floored_pred_var;

        let mean_correction = smoother_means[next] - pred_means[next];
        let var_correction = smoother_vars[next] - floored_pred_var;

        smoother_means[cur] = filter_means[cur] + gain * mean_correction;
        // Keep the smoothed variance strictly positive.
        smoother_vars[cur] = (filter_vars[cur] + gain * gain * var_correction).max(1e-15);
    }

    (smoother_means, smoother_vars)
}
122
123// ────────────────────────────────────────────────────────────────────────────
124// E-step: per-document Dirichlet-LDA variational update
125// ────────────────────────────────────────────────────────────────────────────
126
/// Digamma function ψ(x) for a real argument.
///
/// Positive arguments are shifted upwards with the recurrence
/// ψ(x) = ψ(x + 1) − 1/x until they are at least 6, where the asymptotic series
/// ln z − 1/(2z) − 1/(12z²) + 1/(120z⁴) − 1/(252z⁶) is evaluated.
///
/// Non-positive arguments are not supported and yield the sentinel `-1e10`.
fn digamma(x: f64) -> f64 {
    if x <= 0.0 {
        return -1e10;
    }

    // Accumulate the −1/x terms of the recurrence while moving the argument up.
    let mut shifted = x;
    let mut recurrence = 0.0_f64;
    while shifted < 6.0 {
        recurrence -= 1.0 / shifted;
        shifted += 1.0;
    }

    // Asymptotic expansion at the shifted argument.
    let z = shifted;
    let asymptotic = z.ln() - 0.5 / z - 1.0 / (12.0 * z * z) + 1.0 / (120.0 * z * z * z * z)
        - 1.0 / (252.0 * z * z * z * z * z * z);

    recurrence + asymptotic
}
143
144/// Variational E-step for a single document using Dirichlet-LDA updates.
145///
146/// # Arguments
147/// * `doc_counts`   – word-count vector (length V)
148/// * `gamma`        – Dirichlet variational parameter (length K, updated in place)
149/// * `phi`          – topic assignment probabilities per word (K × V, updated in place)
150/// * `beta_t`       – topic-word distributions at this time slice (K × V)
151/// * `alpha`        – Dirichlet prior concentration
152/// * `max_inner`    – max inner iterations
153fn e_step_doc(
154    doc_counts: &[f64],
155    gamma: &mut [f64],
156    phi: &mut [Vec<f64>],
157    beta_t: &[Vec<f64>],
158    alpha: f64,
159    max_inner: usize,
160) {
161    let k = gamma.len();
162    let vocab = doc_counts.len();
163
164    for _ in 0..max_inner {
165        // Update phi_{dkw} ∝ beta[k][w] * exp(ψ(gamma[k]))
166        let dg: Vec<f64> = gamma.iter().map(|&g| digamma(g)).collect();
167        for w in 0..vocab {
168            if doc_counts[w] <= 0.0 {
169                continue;
170            }
171            let mut row_sum = 0.0_f64;
172            for t in 0..k {
173                let beta_val = if t < beta_t.len() && w < beta_t[t].len() {
174                    beta_t[t][w].max(1e-15)
175                } else {
176                    1e-15
177                };
178                phi[t][w] = beta_val * dg[t].exp();
179                row_sum += phi[t][w];
180            }
181            if row_sum > 1e-15 {
182                for t in 0..k {
183                    phi[t][w] /= row_sum;
184                }
185            }
186        }
187
188        // Update gamma[k] = alpha + Σ_w c_w * phi[k][w]
189        for t in 0..k {
190            let weighted: f64 = (0..vocab).map(|w| doc_counts[w] * phi[t][w]).sum();
191            gamma[t] = alpha + weighted;
192        }
193    }
194}
195
196// ────────────────────────────────────────────────────────────────────────────
197// Main fit routine
198// ────────────────────────────────────────────────────────────────────────────
199
200impl DynamicTopicModel {
201    /// Fit the DTM to a corpus organised by time slice.
202    ///
203    /// # Arguments
204    /// * `docs_by_time` – `T` lists of documents; each document is a word-count vector of length V
205    /// * `vocab_size`   – vocabulary size V
206    ///
207    /// # Returns
208    /// A [`DtmResult`] with topic-word trajectories and doc-topic distributions.
209    pub fn fit(&self, docs_by_time: &[Vec<Vec<f64>>], vocab_size: usize) -> Result<DtmResult> {
210        let n_time = docs_by_time.len();
211        if n_time == 0 {
212            return Err(TextError::InvalidInput(
213                "Empty time-slice collection".into(),
214            ));
215        }
216
217        let k = self.config.n_topics;
218        let v = if vocab_size > 0 {
219            vocab_size
220        } else {
221            docs_by_time
222                .iter()
223                .flat_map(|slice| slice.iter())
224                .map(|d| d.len())
225                .max()
226                .unwrap_or(1)
227        };
228        let sigma_sq = self.config.sigma_sq;
229        let alpha = self.config.alpha;
230        let obs_noise = sigma_sq * 0.1_f64; // heuristic observation noise
231
232        // ── Initialise β trajectories: K × T × V ──────────────────────────
233        // Deterministic perturbation for reproducibility
234        let mut trajectories: Vec<Vec<Vec<f64>>> = (0..k)
235            .map(|ki| {
236                (0..n_time)
237                    .map(|ti| {
238                        let mut row: Vec<f64> = (0..v)
239                            .map(|wi| {
240                                1.0 / v as f64
241                                    + ((ki * 1009 + ti * 997 + wi * 991) % 1000) as f64 * 1e-5
242                            })
243                            .collect();
244                        normalise_to_simplex(&mut row);
245                        row
246                    })
247                    .collect()
248            })
249            .collect();
250
251        // ── Flat doc-topic storage ─────────────────────────────────────────
252        // doc_topic[t][d] = gamma vector (length K)
253        let mut doc_gammas: Vec<Vec<Vec<f64>>> = docs_by_time
254            .iter()
255            .map(|slice| {
256                slice
257                    .iter()
258                    .map(|_| vec![alpha + 1.0_f64 / k as f64; k])
259                    .collect::<Vec<_>>()
260            })
261            .collect();
262
263        for _iter in 0..self.config.max_iter {
264            // ── E-step ──────────────────────────────────────────────────────
265            // Collect sufficient statistics for each topic-word pair across time
266            // suff_stats[k][t][w] = expected count of word w in topic k at time t
267            let mut suff_stats: Vec<Vec<Vec<f64>>> = vec![vec![vec![0.0_f64; v]; n_time]; k];
268
269            for (ti, slice) in docs_by_time.iter().enumerate() {
270                let beta_t: Vec<Vec<f64>> = (0..k).map(|ki| trajectories[ki][ti].clone()).collect();
271
272                for (di, doc_counts) in slice.iter().enumerate() {
273                    let mut phi = vec![vec![0.0_f64; v]; k];
274                    e_step_doc(
275                        doc_counts,
276                        &mut doc_gammas[ti][di],
277                        &mut phi,
278                        &beta_t,
279                        alpha,
280                        5,
281                    );
282                    // Accumulate into suff_stats
283                    for ki in 0..k {
284                        for w in 0..v {
285                            suff_stats[ki][ti][w] +=
286                                doc_counts.get(w).copied().unwrap_or(0.0) * phi[ki][w];
287                        }
288                    }
289                }
290            }
291
292            // ── M-step: Kalman smoother per (k, w) ──────────────────────────
293            for ki in 0..k {
294                for w in 0..v {
295                    // Collect observations for this (k, w) across time
296                    let obs: Vec<f64> = (0..n_time)
297                        .map(|ti| {
298                            let total: f64 = (0..v).map(|ww| suff_stats[ki][ti][ww]).sum();
299                            if total > 1e-15 {
300                                (suff_stats[ki][ti][w] / total).max(1e-15)
301                            } else {
302                                1.0 / v as f64
303                            }
304                        })
305                        .collect();
306
307                    let (fm, fv, pm, pv) = kalman_forward(&obs, sigma_sq, obs_noise);
308                    let (sm, _sv) = kalman_backward(&fm, &fv, &pm, &pv, sigma_sq);
309
310                    for ti in 0..n_time {
311                        trajectories[ki][ti][w] = sm[ti].max(1e-15);
312                    }
313                }
314
315                // Re-normalise each time slice
316                for ti in 0..n_time {
317                    normalise_to_simplex(&mut trajectories[ki][ti]);
318                }
319            }
320        }
321
322        // ── Build doc-topic matrix ─────────────────────────────────────────
323        // Normalise each gamma to get θ_d = gamma / sum(gamma)
324        let mut doc_topic_matrix: Vec<Vec<f64>> = Vec::new();
325        for slice_gammas in &doc_gammas {
326            for gamma in slice_gammas {
327                let s: f64 = gamma.iter().sum();
328                let theta: Vec<f64> = gamma.iter().map(|&g| g / s.max(1e-15)).collect();
329                doc_topic_matrix.push(theta);
330            }
331        }
332
333        Ok(DtmResult {
334            topic_word_trajectories: trajectories,
335            doc_topic_matrix,
336        })
337    }
338}
339
340// ────────────────────────────────────────────────────────────────────────────
341// Tests
342// ────────────────────────────────────────────────────────────────────────────
343
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dtm::{DtmConfig, DynamicTopicModel};

    /// Vocabulary size shared by the `fit` tests.
    const VOCAB: usize = 5;

    /// Builds `n_docs` deterministic count vectors of length `vocab`, with
    /// every count in `0..5`.
    fn make_slice(n_docs: usize, vocab: usize, seed: usize) -> Vec<Vec<f64>> {
        let mut docs = Vec::with_capacity(n_docs);
        for d in 0..n_docs {
            let doc: Vec<f64> = (0..vocab)
                .map(|w| ((d * 3 + w * 7 + seed) % 5) as f64)
                .collect();
            docs.push(doc);
        }
        docs
    }

    /// Fits a 2-topic model on three synthetic time slices.
    fn fit_three_slices(docs_per_slice: usize, max_iter: usize) -> crate::dtm::DtmResult {
        let config = DtmConfig {
            n_topics: 2,
            n_time_slices: 3,
            max_iter,
            sigma_sq: 0.1,
            alpha: 0.1,
        };
        let model = DynamicTopicModel::new(config);
        let docs_by_time: Vec<Vec<Vec<f64>>> = (0..3)
            .map(|t| make_slice(docs_per_slice, VOCAB, t))
            .collect();
        model.fit(&docs_by_time, VOCAB).expect("fit failed")
    }

    #[test]
    fn kalman_forward_correct_shape() {
        let obs = vec![0.1_f64, 0.15, 0.12, 0.18, 0.14];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 0.01, 0.001);
        for series in [&fm, &fv, &pm, &pv].iter() {
            assert_eq!(series.len(), 5);
        }
    }

    #[test]
    fn kalman_backward_smoother_variance_le_filter_variance() {
        let obs = vec![0.1_f64, 0.15, 0.12, 0.18, 0.14, 0.13];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 0.01, 0.001);
        let (_, sv) = kalman_backward(&fm, &fv, &pm, &pv, 0.01);
        // RTS property: smoothing never increases the variance.
        for (i, (&sv_i, &fv_i)) in sv.iter().zip(fv.iter()).enumerate() {
            assert!(
                sv_i <= fv_i + 1e-10,
                "smoother_var[{i}]={sv_i} > filter_var[{i}]={fv_i}"
            );
        }
    }

    #[test]
    fn kalman_roundtrip_recovers_trajectory() {
        // A constant series should be smoothed to (nearly) that constant.
        let truth = 0.2_f64;
        let obs: Vec<f64> = vec![truth; 10];
        let (fm, fv, pm, pv) = kalman_forward(&obs, 1e-4, 1e-3);
        let (sm, _) = kalman_backward(&fm, &fv, &pm, &pv, 1e-4);
        for (i, &m) in sm.iter().enumerate() {
            assert!((m - truth).abs() < 0.05, "smoother[{i}]={m}, truth={truth}");
        }
    }

    #[test]
    fn dtm_fit_trajectories_shape() {
        let res = fit_three_slices(4, 5);
        // Expected shape: K × T × V
        assert_eq!(res.topic_word_trajectories.len(), 2);
        assert_eq!(res.topic_word_trajectories[0].len(), 3);
        assert_eq!(res.topic_word_trajectories[0][0].len(), 5);
    }

    #[test]
    fn dtm_fit_doc_topic_rows_sum_to_one() {
        let res = fit_three_slices(3, 3);
        for (d, row) in res.doc_topic_matrix.iter().enumerate() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "doc {d} topic sum = {s}");
        }
    }

    #[test]
    fn dtm_fit_trajectories_row_sums_to_one() {
        let res = fit_three_slices(3, 3);
        for (ki, topic_traj) in res.topic_word_trajectories.iter().enumerate() {
            for (ti, row) in topic_traj.iter().enumerate() {
                let s: f64 = row.iter().sum();
                assert!(
                    (s - 1.0).abs() < 1e-4,
                    "topic {ki} time {ti} word sum = {s}"
                );
            }
        }
    }
}
453}