1use crate::ctm::model::softmax;
11use crate::ctm::{CorrelatedTopicModel, CtmConfig, CtmResult};
12use crate::error::{Result, TextError};
13
/// Inverts a symmetric positive-definite matrix through its Cholesky factor.
///
/// The matrix is factorised as `A = L Lᵀ`, the lower-triangular factor `L` is
/// inverted by forward substitution, and the result is assembled as
/// `A⁻¹ = L⁻ᵀ L⁻¹`. Only the lower triangle of `a` is read.
///
/// # Returns
///
/// * `Some(inverse)` as a dense `k × k` matrix, where `k = a.len()`.
/// * `None` when a pivot is not strictly positive, when a pivot is NaN or
///   infinite (non-finite input), or when some row is too short to hold its
///   lower-triangular entries.
///
/// An empty input yields `Some` of an empty matrix.
pub fn cholesky_inverse(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let k = a.len();

    // Row `i` must provide at least columns `0..=i`; otherwise the
    // factorisation below would index out of bounds.
    if a.iter().enumerate().any(|(i, row)| row.len() <= i) {
        return None;
    }

    // Cholesky factorisation: A = L Lᵀ with L lower-triangular.
    let mut l = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        for j in 0..=i {
            let mut sum = a[i][j];
            for p in 0..j {
                sum -= l[i][p] * l[j][p];
            }
            if i == j {
                // A plain `sum <= 0.0` test lets NaN through (every comparison
                // with NaN is false) and accepts +inf, both of which would
                // poison the whole inverse. Reject them explicitly.
                if !sum.is_finite() || sum <= 0.0 {
                    return None;
                }
                l[i][j] = sum.sqrt();
            } else {
                l[i][j] = sum / l[j][j];
            }
        }
    }

    // Invert L by forward substitution; L⁻¹ is lower-triangular as well.
    let mut l_inv = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        l_inv[i][i] = 1.0 / l[i][i];
        for j in 0..i {
            let mut sum = 0.0_f64;
            for p in j..i {
                sum -= l[i][p] * l_inv[p][j];
            }
            l_inv[i][j] = sum / l[i][i];
        }
    }

    // A⁻¹ = L⁻ᵀ L⁻¹, i.e. inv[i][j] = Σ_p L⁻¹[p][i] · L⁻¹[p][j].
    let mut inv = vec![vec![0.0_f64; k]; k];
    for i in 0..k {
        for j in 0..k {
            let mut s = 0.0_f64;
            for p in 0..k {
                s += l_inv[p][i] * l_inv[p][j];
            }
            inv[i][j] = s;
        }
    }
    Some(inv)
}
67
/// Adds `eps` to every diagonal entry of `sigma`, in place.
///
/// Off-diagonal entries are left untouched. Each row `i` must have at least
/// `i + 1` columns, otherwise the diagonal access panics.
fn regularise_sigma(sigma: &mut [Vec<f64>], eps: f64) {
    for (idx, row) in sigma.iter_mut().enumerate() {
        row[idx] += eps;
    }
}
75
/// Evaluates the quadratic form `-½ (η − μ)ᵀ Σ⁻¹ (η − μ)`.
///
/// This is the Gaussian log-density of `eta` with the normalising constant
/// left out. `mu` and `sigma_inv` must cover at least `eta.len()` entries
/// (rows and columns); shorter inputs cause an index panic.
///
/// Returns `0.0` for an empty `eta`.
pub fn logistic_normal_ll(eta: &[f64], mu: &[f64], sigma_inv: &[Vec<f64>]) -> f64 {
    // Deviations from the mean, computed once instead of per matrix entry.
    let dev: Vec<f64> = eta.iter().enumerate().map(|(i, &e)| e - mu[i]).collect();

    // Accumulate row by row, column by column, subtracting each term of the
    // quadratic form as it is produced.
    dev.iter().enumerate().fold(0.0_f64, |outer, (i, &di)| {
        dev.iter()
            .enumerate()
            .fold(outer, |acc, (j, &dj)| acc - 0.5 * di * sigma_inv[i][j] * dj)
    })
}
95
96fn expected_theta(nu: &[f64], _sigma2: &[f64]) -> Vec<f64> {
106 softmax(nu)
107}
108
109pub fn e_step_doc(
127 doc_counts: &[f64],
128 nu: &mut [f64],
129 sigma2: &mut [f64],
130 mu: &[f64],
131 sigma_inv: &[Vec<f64>],
132 beta: &[Vec<f64>],
133 max_inner: usize,
134) -> f64 {
135 let k = nu.len();
136 let vocab = doc_counts.len();
137 let n_words: f64 = doc_counts.iter().sum();
138
139 for _ in 0..max_inner {
140 let theta = expected_theta(nu, sigma2);
141
142 for t in 0..k {
147 let prec = sigma_inv[t][t].max(1e-10);
148 sigma2[t] = (1.0 / prec).max(1e-8);
149 }
150
151 for t in 0..k {
158 let mut grad = 0.0_f64;
160
161 for w in 0..vocab {
163 if doc_counts[w] <= 0.0 {
164 continue;
165 }
166 let mut mix = 0.0_f64;
167 for s in 0..k {
168 if s < beta.len() && w < beta[s].len() {
169 mix += theta[s] * beta[s][w];
170 }
171 }
172 if mix > 1e-15 {
173 let phi = if t < beta.len() && w < beta[t].len() {
174 theta[t] * beta[t][w] / mix
175 } else {
176 0.0
177 };
178 grad += doc_counts[w] * (phi - theta[t]);
179 }
180 }
181
182 for j in 0..k {
184 grad -= sigma_inv[t][j] * (nu[j] - mu[j]);
185 }
186
187 let hess = -(n_words * theta[t] * (1.0 - theta[t]) + sigma_inv[t][t])
189 .abs()
190 .max(1e-10);
191
192 let step = (grad / hess).clamp(-2.0, 2.0);
194 nu[t] -= step;
195 }
196 }
197
198 let theta = expected_theta(nu, sigma2);
200 let mut elbo = 0.0_f64;
201
202 for w in 0..vocab {
204 if doc_counts[w] <= 0.0 {
205 continue;
206 }
207 let mut mix = 0.0_f64;
208 for t in 0..k {
209 if t < beta.len() && w < beta[t].len() {
210 mix += theta[t] * beta[t][w];
211 }
212 }
213 if mix > 0.0 {
214 elbo += doc_counts[w] * mix.ln();
215 }
216 }
217
218 elbo += logistic_normal_ll(nu, mu, sigma_inv);
220
221 for t in 0..k {
223 elbo += 0.5 * (1.0 + (2.0 * std::f64::consts::PI * std::f64::consts::E * sigma2[t]).ln());
224 }
225
226 elbo
227}
228
/// Computes the expected topic-word counts of one document.
///
/// For every word `w` with a positive count, the count is split across topics
/// in proportion to `theta[t] * beta[t][w]`:
///
/// `phi[t][w] = doc_counts[w] * theta[t] * beta[t][w] / Σ_s theta[s] * beta[s][w]`
///
/// Words with a non-positive count, or whose mixture probability is below
/// `1e-15`, keep a zero column. Topic/word pairs that `beta` does not cover
/// contribute nothing and keep a zero entry.
///
/// Returns a `theta.len() × doc_counts.len()` matrix.
fn compute_phi(doc_counts: &[f64], theta: &[f64], beta: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut phi = vec![vec![0.0_f64; doc_counts.len()]; theta.len()];

    for (w, &count) in doc_counts.iter().enumerate() {
        if count <= 0.0 {
            continue;
        }

        // Probability of word `w` under topic `t`, when `beta` has that entry.
        let word_prob = |t: usize| beta.get(t).and_then(|row| row.get(w)).copied();

        let mut mix = 0.0_f64;
        for (t, &weight) in theta.iter().enumerate() {
            if let Some(prob) = word_prob(t) {
                mix += weight * prob;
            }
        }
        if mix < 1e-15 {
            continue;
        }

        for (t, phi_row) in phi.iter_mut().enumerate() {
            if let Some(prob) = word_prob(t) {
                phi_row[w] = count * theta[t] * prob / mix;
            }
        }
    }

    phi
}
260
261pub fn m_step_global(
263 doc_counts_list: &[Vec<f64>],
264 nus: &[Vec<f64>],
265 sigma2s: &[Vec<f64>],
266 mu: &mut [f64],
267 sigma: &mut [Vec<f64>],
268 beta: &mut [Vec<f64>],
269) {
270 let n_docs = nus.len();
271 let k = mu.len();
272 let vocab = beta[0].len();
273
274 if n_docs == 0 {
275 return;
276 }
277
278 for t in 0..k {
280 mu[t] = nus.iter().map(|nu| nu[t]).sum::<f64>() / n_docs as f64;
281 }
282
283 for i in 0..k {
285 for j in 0..k {
286 let cov = nus
287 .iter()
288 .map(|nu| (nu[i] - mu[i]) * (nu[j] - mu[j]))
289 .sum::<f64>()
290 / n_docs as f64;
291 sigma[i][j] = cov;
292 }
293 let avg_s2 = sigma2s.iter().map(|s2| s2[i]).sum::<f64>() / n_docs as f64;
295 sigma[i][i] += avg_s2;
296 }
297 regularise_sigma(sigma, 1e-6);
298
299 let mut beta_num = vec![vec![0.0_f64; vocab]; k];
301 for (d, doc_counts) in doc_counts_list.iter().enumerate() {
302 if d >= nus.len() {
303 break;
304 }
305 let theta = expected_theta(&nus[d], &sigma2s[d]);
306 let phi = compute_phi(doc_counts, &theta, beta);
307 for t in 0..k {
308 for w in 0..vocab {
309 beta_num[t][w] += phi[t][w];
310 }
311 }
312 }
313
314 for t in 0..k {
315 let row_sum: f64 = beta_num[t].iter().sum();
316 if row_sum > 1e-15 {
317 for w in 0..vocab {
318 beta[t][w] = (beta_num[t][w] / row_sum).max(1e-15);
319 }
320 } else {
321 let uniform = 1.0 / vocab as f64;
323 for w in 0..vocab {
324 beta[t][w] = uniform;
325 }
326 }
327 }
328}
329
330impl CorrelatedTopicModel {
335 pub fn fit(&self, doc_counts_list: &[Vec<f64>], vocab_size: usize) -> Result<CtmResult> {
344 let k = self.config.n_topics;
345 let n_docs = doc_counts_list.len();
346 if n_docs == 0 {
347 return Err(TextError::InvalidInput("Empty document collection".into()));
348 }
349 let v = if vocab_size > 0 {
350 vocab_size
351 } else {
352 doc_counts_list.iter().map(|d| d.len()).max().unwrap_or(1)
353 };
354
355 let mut mu = vec![0.0_f64; k];
357 let mut sigma: Vec<Vec<f64>> = (0..k)
358 .map(|i| (0..k).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
359 .collect();
360
361 let mut beta: Vec<Vec<f64>> = (0..k)
363 .map(|t| {
364 let mut row = vec![1.0_f64 / v as f64; v];
365 for w in 0..v {
366 let noise = ((t * 1009 + w * 997) % 1000) as f64 * 1e-4;
368 row[w] += noise;
369 }
370 let s: f64 = row.iter().sum();
371 row.iter().map(|&x| x / s).collect()
372 })
373 .collect();
374
375 let mut nus: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![0.0_f64; k]).collect();
377 let mut sigma2s: Vec<Vec<f64>> = (0..n_docs).map(|_| vec![1.0_f64; k]).collect();
378
379 let inner_iters = 5_usize;
380 let mut prev_elbo = f64::NEG_INFINITY;
381
382 for _iter in 0..self.config.max_iter {
383 let sigma_inv_opt = cholesky_inverse(&sigma);
385 let sigma_inv = sigma_inv_opt.unwrap_or_else(|| {
386 (0..k)
388 .map(|i| {
389 (0..k)
390 .map(|j| {
391 if i == j {
392 1.0 / sigma[i][i].max(1e-10)
393 } else {
394 0.0
395 }
396 })
397 .collect()
398 })
399 .collect()
400 });
401
402 let mut total_elbo = 0.0_f64;
403 for d in 0..n_docs {
404 let elbo = e_step_doc(
405 &doc_counts_list[d],
406 &mut nus[d],
407 &mut sigma2s[d],
408 &mu,
409 &sigma_inv,
410 &beta,
411 inner_iters,
412 );
413 total_elbo += elbo;
414 }
415
416 m_step_global(
418 doc_counts_list,
419 &nus,
420 &sigma2s,
421 &mut mu,
422 &mut sigma,
423 &mut beta,
424 );
425
426 if (total_elbo - prev_elbo).abs() < self.config.tol * (1.0 + total_elbo.abs()) {
428 break;
429 }
430 prev_elbo = total_elbo;
431 }
432
433 let doc_topic_matrix: Vec<Vec<f64>> = nus
435 .iter()
436 .zip(sigma2s.iter())
437 .map(|(nu, s2)| expected_theta(nu, s2))
438 .collect();
439
440 let log_likelihood: f64 = doc_counts_list
442 .iter()
443 .zip(doc_topic_matrix.iter())
444 .map(|(doc, theta)| crate::ctm::model::log_likelihood(doc, theta, &beta))
445 .sum();
446
447 Ok(CtmResult {
448 topic_word_matrix: beta,
449 doc_topic_matrix,
450 mu,
451 sigma,
452 log_likelihood,
453 })
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ctm::{CorrelatedTopicModel, CtmConfig};

    /// Builds a deterministic synthetic corpus of `n_docs` count vectors over
    /// `vocab` words, with counts in `0..5`.
    fn make_docs(n_docs: usize, vocab: usize) -> Vec<Vec<f64>> {
        let mut docs = Vec::with_capacity(n_docs);
        for d in 0..n_docs {
            let counts: Vec<f64> = (0..vocab).map(|w| ((d * 3 + w * 7) % 5) as f64).collect();
            docs.push(counts);
        }
        docs
    }

    /// Sums every row of a matrix.
    fn row_sums(rows: &[Vec<f64>]) -> Vec<f64> {
        rows.iter().map(|row| row.iter().sum::<f64>()).collect()
    }

    /// The fitted matrices have one row per topic and one row per document.
    #[test]
    fn ctm_fit_returns_n_topics() {
        let docs = make_docs(6, 8);
        let model = CorrelatedTopicModel::new(CtmConfig {
            n_topics: 3,
            max_iter: 10,
            tol: 1e-3,
            vocab_size: 8,
        });
        let res = model.fit(&docs, 8).expect("fit failed");
        assert_eq!(res.topic_word_matrix.len(), 3);
        assert_eq!(res.doc_topic_matrix.len(), 6);
    }

    /// Every topic is a probability distribution over words.
    #[test]
    fn ctm_fit_topics_sum_to_one() {
        let docs = make_docs(4, 5);
        let model = CorrelatedTopicModel::new(CtmConfig {
            n_topics: 2,
            max_iter: 5,
            tol: 1e-3,
            vocab_size: 5,
        });
        let res = model.fit(&docs, 5).expect("fit failed");
        for (t, s) in row_sums(&res.topic_word_matrix).into_iter().enumerate() {
            assert!((s - 1.0).abs() < 1e-6, "topic {t} word sum = {s}");
        }
    }

    /// Every document's topic proportions form a probability distribution.
    #[test]
    fn ctm_doc_topic_rows_sum_to_one() {
        let docs = make_docs(4, 5);
        let model = CorrelatedTopicModel::new(CtmConfig {
            n_topics: 2,
            max_iter: 5,
            tol: 1e-3,
            vocab_size: 5,
        });
        let res = model.fit(&docs, 5).expect("fit failed");
        for (d, s) in row_sums(&res.doc_topic_matrix).into_iter().enumerate() {
            assert!((s - 1.0).abs() < 1e-6, "doc {d} topic sum = {s}");
        }
    }

    /// The inverse of a diagonal matrix has reciprocal diagonal entries.
    #[test]
    fn cholesky_inverse_identity() {
        let a = vec![
            vec![1.0_f64, 0.0, 0.0],
            vec![0.0, 2.0, 0.0],
            vec![0.0, 0.0, 3.0],
        ];
        let inv = cholesky_inverse(&a).expect("inverse failed");
        let expected_diag = [1.0, 0.5, 1.0 / 3.0];
        for (i, want) in expected_diag.iter().enumerate() {
            assert!((inv[i][i] - want).abs() < 1e-10);
        }
    }

    /// Fits with increasing iteration budgets and checks that the final
    /// log-likelihood is finite (or still the initial sentinel). Monotonicity
    /// itself is not asserted.
    #[test]
    fn ctm_elbo_non_decreasing_first_10_iters() {
        let vocab = 6_usize;
        let docs = make_docs(8, vocab);
        let mut prev_ll = f64::NEG_INFINITY;
        for iters in (1..=10).step_by(2) {
            let model = CorrelatedTopicModel::new(CtmConfig {
                n_topics: 2,
                max_iter: iters,
                tol: 1e-12,
                vocab_size: vocab,
            });
            let res = model.fit(&docs, vocab).expect("fit failed");
            let _ = (res.log_likelihood, prev_ll);
            prev_ll = res.log_likelihood;
        }
        assert!(prev_ll.is_finite() || prev_ll == f64::NEG_INFINITY);
    }
}