1use crate::error::{Result, TextError};
24use scirs2_core::random::prelude::*;
25use scirs2_core::random::{rngs::StdRng, SeedableRng};
26
/// Configuration for the truncation-based HDP topic model.
#[derive(Debug, Clone)]
pub struct HdpConfig {
    /// Truncation level: upper bound on the number of topics the sampler may use.
    pub max_topics: usize,
    /// Document-level concentration; the sampler uses a symmetric
    /// per-topic pseudo-count of `alpha / max_topics`.
    pub alpha: f64,
    /// Top-level (corpus) concentration.
    /// NOTE(review): not referenced by any code in this file — confirm it is
    /// consumed elsewhere or is pending implementation.
    pub gamma: f64,
    /// Symmetric Dirichlet smoothing applied to topic-word counts.
    pub eta: f64,
    /// Number of Gibbs sampling sweeps over the corpus in `fit`.
    pub n_iter: usize,
    /// Optional RNG seed for reproducible fits; `None` draws the RNG state
    /// from the process-wide RNG.
    pub seed: Option<u64>,
}
46
47impl Default for HdpConfig {
48 fn default() -> Self {
49 HdpConfig {
50 max_topics: 20,
51 alpha: 1.0,
52 gamma: 1.0,
53 eta: 0.1,
54 n_iter: 100,
55 seed: None,
56 }
57 }
58}
59
/// Summary statistics returned by a successful `HdpModel::fit`.
#[derive(Debug, Clone)]
pub struct HdpResult {
    /// Number of topics with at least one assigned token after fitting.
    pub n_topics: usize,
    /// `exp(-avg per-token log-likelihood)` over the training corpus; lower is better.
    pub perplexity: f64,
    /// Average per-token log-likelihood of the training corpus.
    pub log_likelihood: f64,
    /// Number of Gibbs sweeps actually performed.
    pub iterations: usize,
}
74
/// Truncated Hierarchical Dirichlet Process topic model fit by collapsed
/// Gibbs sampling.
pub struct HdpModel {
    /// Sampler hyperparameters, including the truncation level.
    config: HdpConfig,
    /// `phi[k][w]`: smoothed topic-word distributions; empty until `fit` runs.
    phi: Vec<Vec<f64>>,
    /// `topic_word_counts[k][w]`: token assignment counts per topic and word.
    topic_word_counts: Vec<Vec<f64>>,
    /// `topic_counts[k]`: total tokens currently assigned to topic `k`.
    topic_counts: Vec<usize>,
    /// Number of topics with a nonzero token count after the last fit.
    pub n_topics_active: usize,
    /// Vocabulary size recorded at fit time (0 before fitting).
    vocab_size: usize,
    /// Set to `true` once `fit` completes successfully.
    is_fitted: bool,
}
103
104impl HdpModel {
105 pub fn new(config: HdpConfig) -> Self {
107 let t = config.max_topics;
108 HdpModel {
109 config,
110 phi: vec![vec![]; t],
111 topic_word_counts: vec![vec![]; t],
112 topic_counts: vec![0; t],
113 n_topics_active: 0,
114 vocab_size: 0,
115 is_fitted: false,
116 }
117 }
118
119 pub fn fit(&mut self, corpus: &[Vec<usize>], vocab_size: usize) -> Result<HdpResult> {
130 if corpus.is_empty() {
131 return Err(TextError::InvalidInput(
132 "corpus must not be empty".to_string(),
133 ));
134 }
135 if vocab_size == 0 {
136 return Err(TextError::InvalidInput(
137 "vocab_size must be > 0".to_string(),
138 ));
139 }
140 for (di, doc) in corpus.iter().enumerate() {
142 for &w in doc {
143 if w >= vocab_size {
144 return Err(TextError::InvalidInput(format!(
145 "word index {w} in document {di} exceeds vocab_size {vocab_size}"
146 )));
147 }
148 }
149 }
150
151 self.vocab_size = vocab_size;
152 let t = self.config.max_topics;
153
154 self.topic_word_counts = vec![vec![0.0f64; vocab_size]; t];
156 self.topic_counts = vec![0usize; t];
157
158 let mut rng = self.make_rng();
159
160 let n_docs = corpus.len();
162 let mut z: Vec<Vec<usize>> = corpus
164 .iter()
165 .map(|doc| {
166 doc.iter()
167 .map(|_| rng.random_range(0..t))
168 .collect::<Vec<usize>>()
169 })
170 .collect();
171
172 let mut theta_counts: Vec<Vec<usize>> = vec![vec![0usize; t]; n_docs];
174
175 for (d, doc) in corpus.iter().enumerate() {
177 for (n, &w) in doc.iter().enumerate() {
178 let k = z[d][n];
179 self.topic_word_counts[k][w] += 1.0;
180 self.topic_counts[k] += 1;
181 theta_counts[d][k] += 1;
182 }
183 }
184
185 let alpha = self.config.alpha;
187 let eta = self.config.eta;
188 let eta_sum = eta * vocab_size as f64;
189
190 let mut iter_done = 0usize;
191 for _iter in 0..self.config.n_iter {
192 for d in 0..n_docs {
193 for n in 0..corpus[d].len() {
194 let w = corpus[d][n];
195 let k_old = z[d][n];
196
197 self.topic_word_counts[k_old][w] -= 1.0;
199 self.topic_counts[k_old] -= 1;
200 theta_counts[d][k_old] -= 1;
201
202 let mut probs = vec![0.0f64; t];
204 for k in 0..t {
205 let doc_factor = theta_counts[d][k] as f64 + alpha / t as f64;
206 let word_factor = (self.topic_word_counts[k][w] + eta)
207 / (self.topic_counts[k] as f64 + eta_sum);
208 probs[k] = doc_factor * word_factor;
209 }
210
211 let k_new = sample_categorical(&probs, &mut rng);
213
214 z[d][n] = k_new;
216 self.topic_word_counts[k_new][w] += 1.0;
217 self.topic_counts[k_new] += 1;
218 theta_counts[d][k_new] += 1;
219 }
220 }
221 iter_done += 1;
222 }
223
224 self.phi = (0..t)
226 .map(|k| {
227 let total = self.topic_counts[k] as f64 + eta_sum;
228 (0..vocab_size)
229 .map(|w| (self.topic_word_counts[k][w] + eta) / total)
230 .collect()
231 })
232 .collect();
233
234 self.n_topics_active = self.topic_counts.iter().filter(|&&c| c > 0).count();
235 self.is_fitted = true;
236
237 let (ll, pp) = self.compute_perplexity(corpus, &theta_counts, eta, eta_sum);
239
240 Ok(HdpResult {
241 n_topics: self.n_topics_active,
242 perplexity: pp,
243 log_likelihood: ll,
244 iterations: iter_done,
245 })
246 }
247
248 pub fn transform(&self, doc: &[usize]) -> Result<Vec<f64>> {
258 if !self.is_fitted {
259 return Err(TextError::ModelNotFitted(
260 "HDP model not fitted yet".to_string(),
261 ));
262 }
263 if doc.is_empty() {
264 return Err(TextError::InvalidInput(
265 "document must not be empty".to_string(),
266 ));
267 }
268
269 let t = self.config.max_topics;
270 let eta = self.config.eta;
271 let eta_sum = eta * self.vocab_size as f64;
272
273 let mut theta = vec![self.config.alpha / t as f64; t];
274
275 for &w in doc {
277 if w >= self.vocab_size {
278 continue;
279 }
280 let mut word_probs: Vec<f64> = (0..t)
281 .map(|k| {
282 theta[k] * (self.topic_word_counts[k][w] + eta)
283 / (self.topic_counts[k] as f64 + eta_sum)
284 })
285 .collect();
286 let sum: f64 = word_probs.iter().sum();
287 if sum > 0.0 {
288 word_probs.iter_mut().for_each(|p| *p /= sum);
289 for k in 0..t {
290 theta[k] += word_probs[k];
291 }
292 }
293 }
294
295 let theta_sum: f64 = theta.iter().sum();
297 if theta_sum > 0.0 {
298 theta.iter_mut().for_each(|p| *p /= theta_sum);
299 }
300
301 Ok(theta)
302 }
303
304 pub fn top_words(&self, n: usize) -> Result<Vec<Vec<usize>>> {
312 if !self.is_fitted {
313 return Err(TextError::ModelNotFitted(
314 "HDP model not fitted yet".to_string(),
315 ));
316 }
317
318 let t = self.config.max_topics;
319 let mut result = Vec::new();
320
321 for k in 0..t {
322 if self.topic_counts[k] == 0 {
323 continue; }
325 let phi_k = &self.phi[k];
326 let mut indices: Vec<usize> = (0..phi_k.len()).collect();
327 indices.sort_by(|&a, &b| {
329 phi_k[b]
330 .partial_cmp(&phi_k[a])
331 .unwrap_or(std::cmp::Ordering::Equal)
332 });
333 indices.truncate(n);
334 result.push(indices);
335 }
336
337 Ok(result)
338 }
339
340 pub fn coherence(&self, corpus: &[Vec<usize>], n_top: usize) -> Result<Vec<f64>> {
350 if !self.is_fitted {
351 return Err(TextError::ModelNotFitted(
352 "HDP model not fitted yet".to_string(),
353 ));
354 }
355
356 let top = self.top_words(n_top)?;
357 let n_docs = corpus.len() as f64;
358
359 let mut df: Vec<f64> = vec![0.0; self.vocab_size];
361 let mut codf: Vec<Vec<f64>> = vec![vec![0.0; self.vocab_size]; self.vocab_size];
362 for doc in corpus {
363 let mut seen = std::collections::HashSet::new();
365 for &w in doc {
366 if w < self.vocab_size && seen.insert(w) {
367 df[w] += 1.0;
368 }
369 }
370 let seen_vec: Vec<usize> = seen.into_iter().collect();
371 for (i, &wi) in seen_vec.iter().enumerate() {
372 for &wj in &seen_vec[i + 1..] {
373 let (a, b) = if wi < wj { (wi, wj) } else { (wj, wi) };
374 codf[a][b] += 1.0;
375 }
376 }
377 }
378
379 let mut scores = Vec::with_capacity(top.len());
380 for topic_words in &top {
381 let mut sum = 0.0f64;
382 let mut count = 0usize;
383 for (i, &wi) in topic_words.iter().enumerate() {
384 for &wj in &topic_words[i + 1..] {
385 let (a, b) = if wi < wj { (wi, wj) } else { (wj, wi) };
386 let co = codf[a][b] + 1.0; let di = df[wi] + 1.0;
388 let dj = df[wj] + 1.0;
389 let pmi = (co / n_docs).ln() - (di / n_docs).ln() - (dj / n_docs).ln();
391 sum += pmi;
392 count += 1;
393 }
394 }
395 scores.push(if count > 0 { sum / count as f64 } else { 0.0 });
396 }
397
398 Ok(scores)
399 }
400
401 fn make_rng(&self) -> StdRng {
404 match self.config.seed {
405 Some(s) => StdRng::seed_from_u64(s),
406 None => StdRng::from_rng(&mut scirs2_core::random::rng()),
407 }
408 }
409
410 fn compute_perplexity(
412 &self,
413 corpus: &[Vec<usize>],
414 theta_counts: &[Vec<usize>],
415 eta: f64,
416 eta_sum: f64,
417 ) -> (f64, f64) {
418 let t = self.config.max_topics;
419 let alpha = self.config.alpha;
420 let mut total_ll = 0.0f64;
421 let mut total_tokens = 0usize;
422
423 for (d, doc) in corpus.iter().enumerate() {
424 let theta_sum: f64 = theta_counts[d].iter().sum::<usize>() as f64 + alpha;
425 for &w in doc {
426 if w >= self.vocab_size {
427 continue;
428 }
429 let p_w: f64 = (0..t)
431 .map(|k| {
432 let theta_dk = (theta_counts[d][k] as f64 + alpha / t as f64) / theta_sum;
433 let phi_kw = (self.topic_word_counts[k][w] + eta)
434 / (self.topic_counts[k] as f64 + eta_sum);
435 theta_dk * phi_kw
436 })
437 .sum();
438 if p_w > 0.0 {
439 total_ll += p_w.ln();
440 }
441 total_tokens += 1;
442 }
443 }
444
445 if total_tokens == 0 {
446 return (0.0, 1.0);
447 }
448
449 let avg_ll = total_ll / total_tokens as f64;
450 let perplexity = (-avg_ll).exp();
451 (avg_ll, perplexity)
452 }
453}
454
455impl std::fmt::Debug for HdpModel {
456 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
457 f.debug_struct("HdpModel")
458 .field("max_topics", &self.config.max_topics)
459 .field("n_topics_active", &self.n_topics_active)
460 .field("vocab_size", &self.vocab_size)
461 .field("is_fitted", &self.is_fitted)
462 .finish()
463 }
464}
465
466fn sample_categorical(probs: &[f64], rng: &mut StdRng) -> usize {
470 let total: f64 = probs.iter().sum();
471 if total <= 0.0 {
472 return rng.random_range(0..probs.len());
474 }
475 let u: f64 = rng.random_range(0.0..total);
476 let mut cumulative = 0.0;
477 for (i, &p) in probs.iter().enumerate() {
478 cumulative += p;
479 if u < cumulative {
480 return i;
481 }
482 }
483 probs.len() - 1 }
485
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a deterministic corpus of `3 * n_per_topic` documents, 20
    /// tokens each, drawn from three disjoint vocabulary bands
    /// (0..5, 5..10, 10..15) so the bands act as latent topics.
    fn synthetic_corpus(n_per_topic: usize) -> Vec<Vec<usize>> {
        let mut rng = StdRng::seed_from_u64(99);
        let bands = [(0usize, 5usize), (5, 10), (10, 15)];
        let mut corpus = Vec::with_capacity(3 * n_per_topic);
        for &(lo, hi) in &bands {
            for _ in 0..n_per_topic {
                let doc: Vec<usize> = (0..20).map(|_| rng.random_range(lo..hi)).collect();
                corpus.push(doc);
            }
        }
        corpus
    }

    #[test]
    fn test_hdp_infers_topics() {
        let corpus = synthetic_corpus(15);
        let mut model = HdpModel::new(HdpConfig {
            max_topics: 20,
            n_iter: 30,
            seed: Some(42),
            ..Default::default()
        });
        let res = model.fit(&corpus, 15).expect("fit should succeed");

        assert!(
            res.n_topics <= 20,
            "active topics ({}) must be <= max_topics",
            res.n_topics
        );
        assert!(
            res.n_topics >= 1,
            "at least one topic must be active, got {}",
            res.n_topics
        );
    }

    #[test]
    fn test_hdp_perplexity_finite() {
        let corpus = synthetic_corpus(10);
        let mut model = HdpModel::new(HdpConfig {
            max_topics: 10,
            n_iter: 20,
            seed: Some(7),
            ..Default::default()
        });
        let res = model.fit(&corpus, 15).expect("fit should succeed");

        assert!(
            res.perplexity.is_finite(),
            "perplexity must be finite, got {}",
            res.perplexity
        );
        assert!(
            res.perplexity > 0.0,
            "perplexity must be positive, got {}",
            res.perplexity
        );
        assert!(
            res.log_likelihood.is_finite(),
            "log_likelihood must be finite"
        );
    }

    #[test]
    fn test_hdp_top_words_valid() {
        let corpus = synthetic_corpus(10);
        let mut model = HdpModel::new(HdpConfig {
            max_topics: 10,
            n_iter: 20,
            seed: Some(1),
            ..Default::default()
        });
        model.fit(&corpus, 15).expect("fit should succeed");

        for topic_words in &model.top_words(5).expect("top_words should succeed") {
            assert!(
                topic_words.len() <= 5,
                "each topic should have <= n top words"
            );
            for &w in topic_words {
                assert!(w < 15, "word index {w} must be < vocab_size 15");
            }
        }
    }

    #[test]
    fn test_hdp_transform() {
        let corpus = synthetic_corpus(10);
        let mut model = HdpModel::new(HdpConfig {
            max_topics: 5,
            n_iter: 15,
            seed: Some(123),
            ..Default::default()
        });
        model.fit(&corpus, 15).expect("fit should succeed");

        let theta = model
            .transform(&[0usize, 1, 2, 3, 0])
            .expect("transform should succeed");
        assert_eq!(theta.len(), 5);
        let total: f64 = theta.iter().sum();
        assert!((total - 1.0).abs() < 1e-9, "topic distribution must sum to 1");
        for &p in &theta {
            assert!(p >= 0.0, "all topic probabilities must be >= 0");
        }
    }

    #[test]
    fn test_hdp_coherence() {
        let corpus = synthetic_corpus(10);
        let mut model = HdpModel::new(HdpConfig {
            max_topics: 5,
            n_iter: 15,
            seed: Some(55),
            ..Default::default()
        });
        model.fit(&corpus, 15).expect("fit should succeed");

        for &s in &model
            .coherence(&corpus, 3)
            .expect("coherence should succeed")
        {
            assert!(s.is_finite(), "coherence score must be finite, got {s}");
        }
    }

    #[test]
    fn test_hdp_empty_corpus_error() {
        let mut model = HdpModel::new(HdpConfig::default());
        assert!(model.fit(&[], 10).is_err());
    }

    #[test]
    fn test_hdp_zero_vocab_error() {
        let mut model = HdpModel::new(HdpConfig::default());
        assert!(model.fit(&[vec![0usize, 1]], 0).is_err());
    }
}