use crate::error::Result;
use crate::topic_modeling::Topic;
use scirs2_core::ndarray::Array2;
use std::collections::{HashMap, HashSet};

/// Computes topic coherence metrics (C_v-style, UMass, and UCI) from fitted
/// topics and a tokenized reference corpus.
pub struct TopicCoherence {
    /// Sliding-window size used when counting word co-occurrences.
    window_size: usize,
    /// Minimum word count (currently unused; reserved for vocabulary filtering).
    _min_count: usize,
    /// Small constant guarding against log(0) and division by zero.
    epsilon: f64,
}

impl Default for TopicCoherence {
    fn default() -> Self {
        Self {
            window_size: 10,
            _min_count: 5,
            epsilon: 1e-12,
        }
    }
}

/// Number of documents containing each word.
type DocFreqMap = HashMap<String, usize>;
/// Number of documents containing each (alphabetically ordered) word pair.
type CoDocFreqMap = HashMap<(String, String), usize>;

impl TopicCoherence {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn with_window_size(mut self, window_size: usize) -> Self {
        self.window_size = window_size;
        self
    }

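    /// Simplified C_v-style coherence: each topic is scored by the average
    /// NPMI over all pairs of its top words (estimated from document-level
    /// co-occurrence rather than the full sliding-window/cosine C_v
    /// formulation), and the per-topic scores are then averaged.
    ///
    /// A minimal usage sketch (`ignore`d in doc tests; it assumes `topics`
    /// produced by one of this crate's topic models and pre-tokenized
    /// `documents`):
    ///
    /// ```ignore
    /// let coherence = TopicCoherence::new();
    /// let score = coherence.cv_coherence(&topics, &documents)?;
    /// assert!((-1.0..=1.0).contains(&score)); // NPMI is bounded
    /// ```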
    pub fn cv_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        let top_words_per_topic: Vec<Vec<String>> = topics
            .iter()
            .map(|topic| {
                topic
                    .top_words
                    .iter()
                    .map(|(word, _)| word.clone())
                    .collect()
            })
            .collect();

        let (doc_freq, co_doc_freq) =
            self.calculate_document_frequencies(&top_words_per_topic, documents)?;

        let mut coherence_scores = Vec::new();

        for topic_words in &top_words_per_topic {
            let topic_coherence = self.calculate_topic_coherence_cv(
                topic_words,
                &doc_freq,
                &co_doc_freq,
                documents.len(),
            )?;
            coherence_scores.push(topic_coherence);
        }

        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

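    /// UMass coherence: for each ordered pair of a topic's top words it
    /// computes `ln((D(w_i, w_j) + eps) / D(w_j))`, where `D` counts the
    /// documents containing the word(s). Scores are <= 0 in practice;
    /// values closer to 0 indicate a more coherent topic.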
    pub fn umass_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        let doc_sets: Vec<HashSet<String>> = documents
            .iter()
            .map(|doc| doc.iter().cloned().collect())
            .collect();

        let mut coherence_scores = Vec::new();

        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();

            let topic_coherence = self.calculate_topic_coherence_umass(&top_words, &doc_sets)?;
            coherence_scores.push(topic_coherence);
        }

        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

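    /// UCI coherence: averages the pointwise mutual information
    /// `PMI(w_i, w_j) = ln(p(w_i, w_j) / (p(w_i) * p(w_j)))` over a topic's
    /// top-word pairs, with probabilities estimated from sliding-window
    /// co-occurrence counts (see `build_co_occurrence_matrix`).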
    pub fn uci_coherence(&self, topics: &[Topic], documents: &[Vec<String>]) -> Result<f64> {
        let (word_freq, co_occurrence) = self.build_co_occurrence_matrix(documents)?;

        let mut coherence_scores = Vec::new();

        for topic in topics {
            let top_words: Vec<&String> = topic.top_words.iter().map(|(word, _)| word).collect();

            let topic_coherence =
                self.calculate_topic_coherence_uci(&top_words, &word_freq, &co_occurrence)?;
            coherence_scores.push(topic_coherence);
        }

        let avg_coherence = coherence_scores.iter().sum::<f64>() / coherence_scores.len() as f64;
        Ok(avg_coherence)
    }

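    /// Counts, for every word that appears in any topic, the number of
    /// documents containing it, and for every word pair the number of
    /// documents containing both.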
    fn calculate_document_frequencies(
        &self,
        topics: &[Vec<String>],
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut doc_freq: DocFreqMap = HashMap::new();
        let mut co_doc_freq: CoDocFreqMap = HashMap::new();

        // Union of every topic's top words; frequencies are only tracked
        // for these.
        let mut all_words: HashSet<String> = HashSet::new();
        for topic in topics {
            for word in topic {
                all_words.insert(word.clone());
            }
        }
        let words_vec: Vec<&String> = all_words.iter().collect();

        for doc in documents {
            let doc_set: HashSet<String> = doc.iter().cloned().collect();

            for word in &all_words {
                if doc_set.contains(word) {
                    *doc_freq.entry(word.clone()).or_insert(0) += 1;
                }
            }

            for i in 0..words_vec.len() {
                for j in (i + 1)..words_vec.len() {
                    let word_1 = words_vec[i];
                    let word_2 = words_vec[j];

                    if doc_set.contains(word_1) && doc_set.contains(word_2) {
                        // Order the pair alphabetically so both orientations
                        // share one map entry.
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_doc_freq.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }

        Ok((doc_freq, co_doc_freq))
    }

    fn calculate_topic_coherence_cv(
        &self,
        topic_words: &[String],
        doc_freq: &DocFreqMap,
        co_doc_freq: &CoDocFreqMap,
        n_docs: usize,
    ) -> Result<f64> {
        let mut scores = Vec::new();

        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = &topic_words[i];
                let word_2 = &topic_words[j];

                let freq1 = doc_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = doc_freq.get(word_2).copied().unwrap_or(0) as f64;

                let key = if word_1 < word_2 {
                    (word_1.clone(), word_2.clone())
                } else {
                    (word_2.clone(), word_1.clone())
                };
                let co_freq = co_doc_freq.get(&key).copied().unwrap_or(0) as f64;

                let npmi = self.calculate_npmi(freq1, freq2, co_freq, n_docs as f64);
                scores.push(npmi);
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

    fn calculate_topic_coherence_umass(
        &self,
        topic_words: &[&String],
        doc_sets: &[HashSet<String>],
    ) -> Result<f64> {
        let mut scores = Vec::new();

        for i in 1..topic_words.len() {
            for j in 0..i {
                let word_i = topic_words[i];
                let word_j = topic_words[j];

                let mut count_j = 0;
                let mut count_both = 0;

                for doc_set in doc_sets {
                    let has_i = doc_set.contains(word_i);
                    let has_j = doc_set.contains(word_j);

                    if has_j {
                        count_j += 1;
                    }
                    if has_i && has_j {
                        count_both += 1;
                    }
                }

                // ln((D(w_i, w_j) + eps) / D(w_j)); when the pair never
                // co-occurs, ln(eps / D(w_j)) acts as a strong penalty.
                let score = if count_both > 0 {
                    ((count_both as f64 + self.epsilon) / count_j as f64).ln()
                } else {
                    (self.epsilon / count_j.max(1) as f64).ln()
                };

                scores.push(score);
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

    fn calculate_topic_coherence_uci(
        &self,
        topic_words: &[&String],
        word_freq: &DocFreqMap,
        co_occurrence: &CoDocFreqMap,
    ) -> Result<f64> {
        let mut scores = Vec::new();
        // Total token count, used to turn counts into probabilities;
        // computed once rather than per word pair.
        let total = word_freq.values().sum::<usize>() as f64;

        for i in 0..topic_words.len() {
            for j in (i + 1)..topic_words.len() {
                let word_1 = topic_words[i];
                let word_2 = topic_words[j];

                let freq1 = word_freq.get(word_1).copied().unwrap_or(0) as f64;
                let freq2 = word_freq.get(word_2).copied().unwrap_or(0) as f64;

                let key = if word_1 < word_2 {
                    (word_1.clone(), word_2.clone())
                } else {
                    (word_2.clone(), word_1.clone())
                };
                let co_freq = co_occurrence.get(&key).copied().unwrap_or(0) as f64;

                if freq1 > 0.0 && freq2 > 0.0 && co_freq > 0.0 {
                    let pmi = (co_freq * total / (freq1 * freq2)).ln();
                    scores.push(pmi);
                }
            }
        }

        if scores.is_empty() {
            Ok(0.0)
        } else {
            Ok(scores.iter().sum::<f64>() / scores.len() as f64)
        }
    }

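    /// Builds corpus-wide word counts and windowed co-occurrence counts.
    /// A pair is counted once per window position: with `window_size = 2`
    /// and the document `["a", "b", "c"]`, only the adjacent pairs
    /// `(a, b)` and `(b, c)` are recorded.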
    fn build_co_occurrence_matrix(
        &self,
        documents: &[Vec<String>],
    ) -> Result<(DocFreqMap, CoDocFreqMap)> {
        let mut word_freq: DocFreqMap = HashMap::new();
        let mut co_occurrence: CoDocFreqMap = HashMap::new();

        for doc in documents {
            for word in doc {
                *word_freq.entry(word.clone()).or_insert(0) += 1;
            }

            for i in 0..doc.len() {
                let window_end = (i + self.window_size).min(doc.len());

                for j in (i + 1)..window_end {
                    let word_1 = &doc[i];
                    let word_2 = &doc[j];

                    if word_1 != word_2 {
                        let key = if word_1 < word_2 {
                            (word_1.clone(), word_2.clone())
                        } else {
                            (word_2.clone(), word_1.clone())
                        };
                        *co_occurrence.entry(key).or_insert(0) += 1;
                    }
                }
            }
        }

        Ok((word_freq, co_occurrence))
    }

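    /// Normalized PMI from document-level counts:
    /// `NPMI = ln(p12 / (p1 * p2)) / -ln(p12)`, clamped to `[-1, 1]`.
    /// Returns -1.0 (never co-occur) when any count is zero.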
    fn calculate_npmi(&self, freq1: f64, freq2: f64, co_freq: f64, n_total: f64) -> f64 {
        if freq1 == 0.0 || freq2 == 0.0 || co_freq == 0.0 {
            return -1.0;
        }

        let p1 = freq1 / n_total;
        let p2 = freq2 / n_total;
        let p12 = co_freq / n_total;

        // When the pair occurs in every document, -ln(p12) is zero and the
        // ratio below would be NaN; treat this as perfect association.
        if p12 >= 1.0 {
            return 1.0;
        }

        let pmi = (p12 / (p1 * p2)).ln();
        let npmi = pmi / -(p12.ln());

        npmi.clamp(-1.0, 1.0)
    }
}

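/// Diversity and distance measures over topics' top words.
///
/// `calculate` returns the fraction of unique words across all topics'
/// top-word lists: 1.0 means no word is shared between topics, while values
/// near 0.0 mean heavy overlap.
///
/// A minimal sketch (`ignore`d; assumes `Topic` values shaped as in the
/// tests below):
///
/// ```ignore
/// let diversity = TopicDiversity::calculate(&topics);
/// assert!((0.0..=1.0).contains(&diversity));
/// ```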
pub struct TopicDiversity;

impl TopicDiversity {
    pub fn calculate(topics: &[Topic]) -> f64 {
        let mut all_words = Vec::new();
        let mut unique_words = HashSet::new();

        for topic in topics {
            for (word, _) in &topic.top_words {
                all_words.push(word.clone());
                unique_words.insert(word.clone());
            }
        }

        if all_words.is_empty() {
            return 0.0;
        }

        unique_words.len() as f64 / all_words.len() as f64
    }

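    /// Pairwise Jaccard distances between topics' top-word sets:
    /// `d(i, j) = 1 - |W_i ∩ W_j| / |W_i ∪ W_j|`, with zeros on the
    /// diagonal.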
    pub fn pairwise_distances(topics: &[Topic]) -> Array2<f64> {
        let n_topics = topics.len();
        let mut distances = Array2::zeros((n_topics, n_topics));

        for i in 0..n_topics {
            for j in 0..n_topics {
                if i == j {
                    distances[[i, j]] = 0.0;
                } else {
                    let words_i: HashSet<String> = topics[i]
                        .top_words
                        .iter()
                        .map(|(word, _)| word.clone())
                        .collect();
                    let words_j: HashSet<String> = topics[j]
                        .top_words
                        .iter()
                        .map(|(word, _)| word.clone())
                        .collect();

                    let intersection = words_i.intersection(&words_j).count();
                    let union = words_i.union(&words_j).count();

                    distances[[i, j]] = 1.0 - (intersection as f64 / union as f64);
                }
            }
        }

        distances
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_topics() -> Vec<Topic> {
        vec![
            Topic {
                id: 0,
                top_words: vec![
                    ("machine".to_string(), 0.1),
                    ("learning".to_string(), 0.09),
                    ("algorithm".to_string(), 0.08),
                ],
                coherence: None,
            },
            Topic {
                id: 1,
                top_words: vec![
                    ("neural".to_string(), 0.12),
                    ("network".to_string(), 0.11),
                    ("deep".to_string(), 0.10),
                ],
                coherence: None,
            },
        ]
    }

    fn create_test_documents() -> Vec<Vec<String>> {
        vec![
            vec!["machine", "learning", "algorithm", "data"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["neural", "network", "deep", "learning"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["machine", "algorithm", "neural", "network"]
                .into_iter()
                .map(String::from)
                .collect(),
            vec!["deep", "learning", "machine", "data"]
                .into_iter()
                .map(String::from)
                .collect(),
        ]
    }

    #[test]
    fn test_cv_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .cv_coherence(&topics, &documents)
            .expect("cv_coherence should succeed on non-empty input");
        assert!((-1.0..=1.0).contains(&score));
    }

    #[test]
    fn test_umass_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .umass_coherence(&topics, &documents)
            .expect("umass_coherence should succeed on non-empty input");
        assert!(score.is_finite());
    }

    #[test]
    fn test_uci_coherence() {
        let coherence = TopicCoherence::new();
        let topics = create_test_topics();
        let documents = create_test_documents();

        let score = coherence
            .uci_coherence(&topics, &documents)
            .expect("uci_coherence should succeed on non-empty input");
        assert!(score.is_finite());
    }

    #[test]
    fn test_topic_diversity() {
        let topics = create_test_topics();
        let diversity = TopicDiversity::calculate(&topics);

        assert!((0.0..=1.0).contains(&diversity));
        // The two test topics share no words, so diversity is exactly 1.0.
        assert_eq!(diversity, 1.0);
    }

    #[test]
    fn test_pairwise_distances() {
        let topics = create_test_topics();
        let distances = TopicDiversity::pairwise_distances(&topics);

        // Diagonal entries are self-distances.
        assert_eq!(distances[[0, 0]], 0.0);
        assert_eq!(distances[[1, 1]], 0.0);

        // Disjoint topics have the maximum Jaccard distance.
        assert_eq!(distances[[0, 1]], 1.0);
        assert_eq!(distances[[1, 0]], 1.0);
    }

    #[test]
    fn test_empty_topics() {
        let coherence = TopicCoherence::new();
        let topics: Vec<Topic> = vec![];
        let documents = create_test_documents();

        // With no topics the average is 0/0, so NaN is acceptable here.
        let cv_score = coherence
            .cv_coherence(&topics, &documents)
            .expect("cv_coherence should not error on empty topics");
        assert!(cv_score.is_nan() || cv_score == 0.0);

        let diversity = TopicDiversity::calculate(&topics);
        assert_eq!(diversity, 0.0);
    }
}