Skip to main content

scirs2_text/ctm/
inference.rs

//! Variational inference for the Correlated Topic Model (CTM).
//!
//! Implements:
//! - Logistic-normal log-likelihood helper
//! - Per-document E-step (coordinate ascent on variational parameters)
//! - Global M-step (update µ, Σ, β)
//! - Cholesky (L·Lᵀ) matrix inverse for small K×K matrices
//! - Main `fit` routine on `CorrelatedTopicModel`
9
10use crate::ctm::model::softmax;
11use crate::ctm::{CorrelatedTopicModel, CtmConfig, CtmResult};
12use crate::error::{Result, TextError};
13
14// ────────────────────────────────────────────────────────────────────────────
15// Small-matrix helpers
16// ────────────────────────────────────────────────────────────────────────────
17
/// Compute the inverse of a symmetric positive-definite K×K matrix via
/// Cholesky factorisation (`A = L Lᵀ`, no external BLAS needed).
///
/// Only the lower triangle of `a` (entries `a[i][j]` with `j <= i`) is read;
/// the matrix is assumed to be symmetric.
///
/// # Returns
/// * `Some(inverse)` – the full (symmetric) K×K inverse. An empty input yields
///   an empty matrix.
/// * `None` – if a row is too short to hold its lower-triangular part, or if
///   the matrix is not numerically positive-definite (a pivot is non-positive,
///   NaN or infinite).
pub fn cholesky_inverse(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let k = a.len();

    // Row i must provide columns 0..=i; report failure instead of panicking
    // on an out-of-range index.
    if a.iter().enumerate().any(|(i, row)| row.len() <= i) {
        return None;
    }

    // Cholesky: A = L L^T  (lower triangular L)
    let mut l = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        for j in 0..=i {
            let mut sum = a[i][j];
            for p in 0..j {
                sum -= l[i][p] * l[j][p];
            }
            if i == j {
                // A NaN/inf anywhere in row i ends up in this pivot, so this
                // single test also rejects non-finite input.
                if !sum.is_finite() || sum <= 0.0 {
                    return None; // not positive-definite
                }
                l[i][j] = sum.sqrt();
            } else {
                l[i][j] = sum / l[j][j];
            }
        }
    }

    // Compute L^{-1} (also lower triangular) by forward substitution.
    let mut l_inv = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        l_inv[i][i] = 1.0 / l[i][i];
        for j in 0..i {
            let mut sum = 0.0_f64;
            for p in j..i {
                sum -= l[i][p] * l_inv[p][j];
            }
            l_inv[i][j] = sum / l[i][i];
        }
    }

    // A^{-1} = (L^{-1})^T L^{-1}.  Because L^{-1} is lower triangular, the
    // product l_inv[p][i] * l_inv[p][j] vanishes for p < max(i, j); the
    // result is symmetric, so only one triangle is computed and mirrored.
    let mut inv = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        for j in 0..=i {
            let mut s = 0.0_f64;
            for p in i..k {
                s += l_inv[p][i] * l_inv[p][j];
            }
            inv[i][j] = s;
            inv[j][i] = s;
        }
    }
    Some(inv)
}
67
/// Shift the diagonal of `sigma` by `eps` in place (ridge regularisation),
/// which keeps the covariance estimate away from singularity.
///
/// Panics if a row is too short to contain its diagonal entry.
fn regularise_sigma(sigma: &mut [Vec<f64>], eps: f64) {
    for (diag, row) in sigma.iter_mut().enumerate() {
        row[diag] += eps;
    }
}
75
76// ────────────────────────────────────────────────────────────────────────────
77// Logistic-normal log-likelihood
78// ────────────────────────────────────────────────────────────────────────────
79
/// Unnormalised Gaussian log-density of `eta` under N(µ, Σ), where the
/// covariance is supplied through its inverse (the precision matrix
/// `sigma_inv` = Σ⁻¹).
///
/// `log p(eta) = -½ (eta-µ)^T Σ^{-1} (eta-µ) + const`
///
/// The additive constant (which includes the log-determinant of Σ) is not
/// part of the returned value.
pub fn logistic_normal_ll(eta: &[f64], mu: &[f64], sigma_inv: &[Vec<f64>]) -> f64 {
    let dim = eta.len();
    // Deviations from the mean, computed once and reused for rows and columns.
    let centred: Vec<f64> = (0..dim).map(|idx| eta[idx] - mu[idx]).collect();

    let mut log_density = 0.0_f64;
    for (row, &dev_row) in centred.iter().enumerate() {
        for (col, &dev_col) in centred.iter().enumerate() {
            log_density -= 0.5 * dev_row * sigma_inv[row][col] * dev_col;
        }
    }
    log_density
}
95
96// ────────────────────────────────────────────────────────────────────────────
97// Sufficient statistics from variational posterior
98// ────────────────────────────────────────────────────────────────────────────
99
100/// Compute expected topic proportions (θ) via the softmax of the variational
101/// mean `nu`, using Monte-Carlo approximation with a fixed number of samples.
102///
103/// For efficiency we use a first-order delta method approximation instead:
104/// `E[softmax(nu + eps)] ≈ softmax(nu)` which is exact in the limit of small σ².
105fn expected_theta(nu: &[f64], _sigma2: &[f64]) -> Vec<f64> {
106    softmax(nu)
107}
108
109// ────────────────────────────────────────────────────────────────────────────
110// Per-document E-step
111// ────────────────────────────────────────────────────────────────────────────
112
113/// Perform coordinate ascent on the variational parameters (nu, sigma2) for
114/// a single document.
115///
116/// # Arguments
117/// * `doc_counts` – word count vector (length V)
118/// * `nu`         – variational mean (updated in place, length K)
119/// * `sigma2`     – variational diagonal variance (updated in place, length K)
120/// * `mu`         – prior mean (length K)
121/// * `sigma_inv`  – prior precision matrix (K×K)
122/// * `beta`       – topic-word matrix (K×V)
123/// * `max_inner`  – max coordinate-ascent iterations
124///
125/// Returns the approximate ELBO contribution for this document.
126pub fn e_step_doc(
127    doc_counts: &[f64],
128    nu: &mut [f64],
129    sigma2: &mut [f64],
130    mu: &[f64],
131    sigma_inv: &[Vec<f64>],
132    beta: &[Vec<f64>],
133    max_inner: usize,
134) -> f64 {
135    let k = nu.len();
136    let vocab = doc_counts.len();
137    let n_words: f64 = doc_counts.iter().sum();
138
139    for _ in 0..max_inner {
140        let theta = expected_theta(nu, sigma2);
141
142        // ── Update sigma2_k (closed form) ──────────────────────────────────
143        // ELBO w.r.t. sigma2_k: -½ σ_inv[k][k] sigma2_k + ½ log(sigma2_k) + entropy
144        // Optimum: sigma2_k = 1 / sigma_inv[k][k]
145        // (Ignoring the expected Hessian of the log-normaliser for simplicity)
146        for t in 0..k {
147            let prec = sigma_inv[t][t].max(1e-10);
148            sigma2[t] = (1.0 / prec).max(1e-8);
149        }
150
151        // ── Update nu_k via Newton step ────────────────────────────────────
152        // ELBO gradient w.r.t. nu_k:
153        //   grad_k = Σ_w c_w * (phi_kw - theta_k) - Σ_j σ_inv[k][j] (nu_j - mu_j)
154        // where phi_kw = theta_k * beta[k][w] / (Σ_t theta_t * beta[t][w])
155        //
156        // Hessian diagonal (diagonal approx): -(n * theta_k (1-theta_k) + σ_inv[k][k])
157        for t in 0..k {
158            // Compute gradient
159            let mut grad = 0.0_f64;
160
161            // Word-model term
162            for w in 0..vocab {
163                if doc_counts[w] <= 0.0 {
164                    continue;
165                }
166                let mut mix = 0.0_f64;
167                for s in 0..k {
168                    if s < beta.len() && w < beta[s].len() {
169                        mix += theta[s] * beta[s][w];
170                    }
171                }
172                if mix > 1e-15 {
173                    let phi = if t < beta.len() && w < beta[t].len() {
174                        theta[t] * beta[t][w] / mix
175                    } else {
176                        0.0
177                    };
178                    grad += doc_counts[w] * (phi - theta[t]);
179                }
180            }
181
182            // Prior term: -Σ_j σ_inv[t][j] (nu_j - mu_j)
183            for j in 0..k {
184                grad -= sigma_inv[t][j] * (nu[j] - mu[j]);
185            }
186
187            // Diagonal Hessian approximation
188            let hess = -(n_words * theta[t] * (1.0 - theta[t]) + sigma_inv[t][t])
189                .abs()
190                .max(1e-10);
191
192            // Damped Newton step
193            let step = (grad / hess).clamp(-2.0, 2.0);
194            nu[t] -= step;
195        }
196    }
197
198    // ── Compute approximate ELBO for this document ─────────────────────────
199    let theta = expected_theta(nu, sigma2);
200    let mut elbo = 0.0_f64;
201
202    // Log-likelihood term
203    for w in 0..vocab {
204        if doc_counts[w] <= 0.0 {
205            continue;
206        }
207        let mut mix = 0.0_f64;
208        for t in 0..k {
209            if t < beta.len() && w < beta[t].len() {
210                mix += theta[t] * beta[t][w];
211            }
212        }
213        if mix > 0.0 {
214            elbo += doc_counts[w] * mix.ln();
215        }
216    }
217
218    // Gaussian prior term
219    elbo += logistic_normal_ll(nu, mu, sigma_inv);
220
221    // Entropy of variational distribution (diagonal Gaussian)
222    for t in 0..k {
223        elbo += 0.5 * (1.0 + (2.0 * std::f64::consts::PI * std::f64::consts::E * sigma2[t]).ln());
224    }
225
226    elbo
227}
228
229// ────────────────────────────────────────────────────────────────────────────
230// Global M-step
231// ────────────────────────────────────────────────────────────────────────────
232
/// Expected topic-assignment counts for one document.
///
/// Entry `[t][w]` of the returned K×V matrix (K = `theta.len()`,
/// V = `doc_counts.len()`) is `doc_counts[w]` multiplied by the posterior
/// responsibility of topic `t` for word `w`, i.e.
/// `theta[t] * beta[t][w] / Σ_s theta[s] * beta[s][w]`.
///
/// Words with a non-positive count or a mixture probability below `1e-15`
/// keep an all-zero column; topics whose `beta` row is missing or too short
/// for a word contribute nothing for that word.
fn compute_phi(doc_counts: &[f64], theta: &[f64], beta: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let n_topics = theta.len();
    let mut expected = vec![vec![0.0_f64; doc_counts.len()]; n_topics];

    for (word, &count) in doc_counts.iter().enumerate() {
        if count <= 0.0 {
            continue;
        }

        // beta[topic][word] where available, `None` otherwise.
        let word_probs: Vec<Option<f64>> = (0..n_topics)
            .map(|topic| beta.get(topic).and_then(|row| row.get(word)).copied())
            .collect();

        let mix = theta
            .iter()
            .zip(&word_probs)
            .fold(0.0_f64, |acc, (&th, prob)| match prob {
                Some(b) => acc + th * b,
                None => acc,
            });
        if mix < 1e-15 {
            continue;
        }

        for ((row, &th), prob) in expected.iter_mut().zip(theta).zip(&word_probs) {
            if let Some(b) = prob {
                row[word] = count * th * b / mix;
            }
        }
    }
    expected
}
260
261/// Perform the global M-step: update µ, Σ, and β.
262pub fn m_step_global(
263    doc_counts_list: &[Vec<f64>],
264    nus: &[Vec<f64>],
265    sigma2s: &[Vec<f64>],
266    mu: &mut [f64],
267    sigma: &mut [Vec<f64>],
268    beta: &mut [Vec<f64>],
269) {
270    let n_docs = nus.len();
271    let k = mu.len();
272    let vocab = beta[0].len();
273
274    if n_docs == 0 {
275        return;
276    }
277
278    // ── Update µ: sample mean of nus ──────────────────────────────────────
279    for t in 0..k {
280        mu[t] = nus.iter().map(|nu| nu[t]).sum::<f64>() / n_docs as f64;
281    }
282
283    // ── Update Σ: sample covariance + average diagonal sigma2 ─────────────
284    for i in 0..k {
285        for j in 0..k {
286            let cov = nus
287                .iter()
288                .map(|nu| (nu[i] - mu[i]) * (nu[j] - mu[j]))
289                .sum::<f64>()
290                / n_docs as f64;
291            sigma[i][j] = cov;
292        }
293        // Add average variational variance to diagonal
294        let avg_s2 = sigma2s.iter().map(|s2| s2[i]).sum::<f64>() / n_docs as f64;
295        sigma[i][i] += avg_s2;
296    }
297    regularise_sigma(sigma, 1e-6);
298
299    // ── Update β: expected word counts ────────────────────────────────────
300    let mut beta_num = vec![vec![0.0_f64; vocab]; k];
301    for (d, doc_counts) in doc_counts_list.iter().enumerate() {
302        if d >= nus.len() {
303            break;
304        }
305        let theta = expected_theta(&nus[d], &sigma2s[d]);
306        let phi = compute_phi(doc_counts, &theta, beta);
307        for t in 0..k {
308            for w in 0..vocab {
309                beta_num[t][w] += phi[t][w];
310            }
311        }
312    }
313
314    for t in 0..k {
315        let row_sum: f64 = beta_num[t].iter().sum();
316        if row_sum > 1e-15 {
317            for w in 0..vocab {
318                beta[t][w] = (beta_num[t][w] / row_sum).max(1e-15);
319            }
320        } else {
321            // Uniform fallback
322            let uniform = 1.0 / vocab as f64;
323            for w in 0..vocab {
324                beta[t][w] = uniform;
325            }
326        }
327    }
328}
329
330// ────────────────────────────────────────────────────────────────────────────
331// Main fit routine
332// ────────────────────────────────────────────────────────────────────────────
333
334impl CorrelatedTopicModel {
335    /// Fit the CTM to a collection of documents represented as word-count vectors.
336    ///
337    /// # Arguments
338    /// * `doc_counts_list` – one count vector per document (length V each)
339    /// * `vocab_size`      – vocabulary size V (must equal `doc_counts_list[d].len()`)
340    ///
341    /// # Returns
342    /// A [`CtmResult`] containing the fitted parameters.
343    pub fn fit(&self, doc_counts_list: &[Vec<f64>], vocab_size: usize) -> Result<CtmResult> {
344        let k = self.config.n_topics;
345        let n_docs = doc_counts_list.len();
346        if n_docs == 0 {
347            return Err(TextError::InvalidInput("Empty document collection".into()));
348        }
349        let v = if vocab_size > 0 {
350            vocab_size
351        } else {
352            doc_counts_list.iter().map(|d| d.len()).max().unwrap_or(1)
353        };
354
355        // ── Initialise parameters ─────────────────────────────────────────
356        let mut mu = vec![0.0_f64; k];
357        let mut sigma: Vec<Vec<f64>> = (0..k)
358            .map(|i| (0..k).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
359            .collect();
360
361        // Initialise β with small random perturbation (deterministic seed)
362        let mut beta: Vec<Vec<f64>> = (0..k)
363            .map(|t| {
364                let mut row = vec![1.0_f64 / v as f64; v];
365                for w in 0..v {
366                    // Cheap deterministic perturbation
367                    let noise = ((t * 1009 + w * 997) % 1000) as f64 * 1e-4;
368                    row[w] += noise;
369                }
370                let s: f64 = row.iter().sum();
371                row.iter().map(|&x| x / s).collect()
372            })
373            .collect();
374
375        // Per-document variational parameters
376        let mut nus: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![0.0_f64; k]).collect();
377        let mut sigma2s: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![1.0_f64; k]).collect();
378
379        let inner_iters = 5_usize;
380        let mut prev_elbo = f64::NEG_INFINITY;
381
382        for _iter in 0..self.config.max_iter {
383            // ── E-step ──────────────────────────────────────────────────────
384            let sigma_inv_opt = cholesky_inverse(&sigma);
385            let sigma_inv = sigma_inv_opt.unwrap_or_else(|| {
386                // Fallback: diagonal inverse
387                (0..k)
388                    .map(|i| {
389                        (0..k)
390                            .map(|j| {
391                                if i == j {
392                                    1.0 / sigma[i][i].max(1e-10)
393                                } else {
394                                    0.0
395                                }
396                            })
397                            .collect()
398                    })
399                    .collect()
400            });
401
402            let mut total_elbo = 0.0_f64;
403            for d in 0..n_docs {
404                let elbo = e_step_doc(
405                    &doc_counts_list[d],
406                    &mut nus[d],
407                    &mut sigma2s[d],
408                    &mu,
409                    &sigma_inv,
410                    &beta,
411                    inner_iters,
412                );
413                total_elbo += elbo;
414            }
415
416            // ── M-step ──────────────────────────────────────────────────────
417            m_step_global(
418                doc_counts_list,
419                &nus,
420                &sigma2s,
421                &mut mu,
422                &mut sigma,
423                &mut beta,
424            );
425
426            // ── Convergence check ────────────────────────────────────────────
427            if (total_elbo - prev_elbo).abs() < self.config.tol * (1.0 + total_elbo.abs()) {
428                break;
429            }
430            prev_elbo = total_elbo;
431        }
432
433        // ── Build doc-topic matrix ─────────────────────────────────────────
434        let doc_topic_matrix: Vec<Vec<f64>> = nus
435            .iter()
436            .zip(sigma2s.iter())
437            .map(|(nu, s2)| expected_theta(nu, s2))
438            .collect();
439
440        // ── Final log-likelihood ───────────────────────────────────────────
441        let log_likelihood: f64 = doc_counts_list
442            .iter()
443            .zip(doc_topic_matrix.iter())
444            .map(|(doc, theta)| crate::ctm::model::log_likelihood(doc, theta, &beta))
445            .sum();
446
447        Ok(CtmResult {
448            topic_word_matrix: beta,
449            doc_topic_matrix,
450            mu,
451            sigma,
452            log_likelihood,
453        })
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ctm::{CorrelatedTopicModel, CtmConfig};

    /// Deterministic synthetic corpus: `n_docs` count vectors of length
    /// `vocab` with entries in `0..5`.
    fn make_docs(n_docs: usize, vocab: usize) -> Vec<Vec<f64>> {
        (0..n_docs)
            .map(|d| (0..vocab).map(|w| ((d * 3 + w * 7) % 5) as f64).collect())
            .collect()
    }

    /// Build a model with the given hyper-parameters.
    fn make_model(n_topics: usize, max_iter: usize, tol: f64, vocab: usize) -> CorrelatedTopicModel {
        CorrelatedTopicModel::new(CtmConfig {
            n_topics,
            max_iter,
            tol,
            vocab_size: vocab,
        })
    }

    #[test]
    fn ctm_fit_returns_n_topics() {
        let model = make_model(3, 10, 1e-3, 8);
        let docs = make_docs(6, 8);
        let res = model.fit(&docs, 8).expect("fit failed");
        assert_eq!(res.topic_word_matrix.len(), 3);
        assert_eq!(res.doc_topic_matrix.len(), 6);
    }

    #[test]
    fn ctm_fit_topics_sum_to_one() {
        let model = make_model(2, 5, 1e-3, 5);
        let docs = make_docs(4, 5);
        let res = model.fit(&docs, 5).expect("fit failed");
        for (t, row) in res.topic_word_matrix.iter().enumerate() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "topic {t} word sum = {s}");
        }
    }

    #[test]
    fn ctm_doc_topic_rows_sum_to_one() {
        let model = make_model(2, 5, 1e-3, 5);
        let docs = make_docs(4, 5);
        let res = model.fit(&docs, 5).expect("fit failed");
        for (d, row) in res.doc_topic_matrix.iter().enumerate() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-6, "doc {d} topic sum = {s}");
        }
    }

    #[test]
    fn cholesky_inverse_identity() {
        // Diagonal matrix: the inverse is the diagonal of reciprocals.
        let a = vec![
            vec![1.0_f64, 0.0, 0.0],
            vec![0.0, 2.0, 0.0],
            vec![0.0, 0.0, 3.0],
        ];
        let inv = cholesky_inverse(&a).expect("inverse failed");
        let expected_diag = [1.0, 0.5, 1.0 / 3.0];
        for i in 0..3 {
            for j in 0..3 {
                let want = if i == j { expected_diag[i] } else { 0.0 };
                assert!(
                    (inv[i][j] - want).abs() < 1e-10,
                    "inv[{i}][{j}] = {}",
                    inv[i][j]
                );
            }
        }

        // A symmetric but indefinite matrix must be rejected.
        let indefinite = vec![vec![1.0_f64, 2.0], vec![2.0, 1.0]];
        assert!(cholesky_inverse(&indefinite).is_none());
    }

    #[test]
    fn ctm_elbo_non_decreasing_first_10_iters() {
        // Smoke test: fit with an increasing iteration budget (1, 3, 5, 7, 9)
        // and a tolerance small enough to prevent early stopping.  Each call
        // re-initialises the model, so the log-likelihoods of different runs
        // are not compared; the test only requires every fit to succeed and
        // the last log-likelihood to be a number (finite or -inf, never NaN).
        let vocab = 6_usize;
        let docs = make_docs(8, vocab);
        let mut last_ll = f64::NEG_INFINITY;
        for iters in (1..=10).step_by(2) {
            let model = make_model(2, iters, 1e-12, vocab);
            let res = model.fit(&docs, vocab).expect("fit failed");
            last_ll = res.log_likelihood;
        }
        assert!(last_ll.is_finite() || last_ll == f64::NEG_INFINITY);
    }
}