// topic_coherence_demo/topic_coherence_demo.rs

use scirs2_text::{
    LatentDirichletAllocation, LdaTopic, Tokenizer, TopicCoherence, TopicDiversity,
    WhitespaceTokenizer,
};
use std::collections::HashMap;

#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("Topic Coherence Evaluation Demo");
    println!("===============================\n");

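    // A toy corpus: each string below is treated as one document.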
    let documents = vec![
        "Machine learning algorithms are used in artificial intelligence",
        "Deep learning neural networks process complex data patterns",
        "Natural language processing enables text understanding",
        "Computer vision algorithms detect objects in images",
        "Reinforcement learning agents learn through trial and error",
        "Supervised learning requires labeled training data",
        "Unsupervised learning discovers hidden patterns",
        "Transfer learning reuses pretrained models",
        "Statistical models analyze numerical data distributions",
        "Regression analysis predicts continuous outcomes",
        "Classification algorithms categorize data points",
        "Time series analysis forecasts temporal patterns",
        "Clustering groups similar data together",
        "Feature engineering improves model performance",
        "Model validation prevents overfitting",
    ];

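    // Split each document into tokens on whitespace. A plain whitespace
    // tokenizer leaves case and punctuation untouched, so words are matched
    // verbatim by the metrics below.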
    let tokenizer = WhitespaceTokenizer::new();
    let tokenized_docs: Vec<Vec<String>> = documents
        .iter()
        .map(|doc| tokenizer.tokenize(doc).expect("tokenization failed"))
        .collect();

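    // Assign each distinct token a dense integer id in order of first appearance.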
    let mut vocabulary = HashMap::new();
    let mut word_id = 0;

    for doc in &tokenized_docs {
        for word in doc {
            if !vocabulary.contains_key(word) {
                vocabulary.insert(word.clone(), word_id);
                word_id += 1;
            }
        }
    }

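    // Build the (n_docs x n_words) document-term count matrix consumed by LDA.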
    let n_docs = tokenized_docs.len();
    let n_words = vocabulary.len();
    let mut doc_term_matrix = scirs2_core::ndarray::Array2::zeros((n_docs, n_words));

    for (doc_idx, doc) in tokenized_docs.iter().enumerate() {
        for word in doc {
            if let Some(&word_id) = vocabulary.get(word) {
                doc_term_matrix[[doc_idx, word_id]] += 1.0;
            }
        }
    }

    println!("1. Training LDA Model");
    println!("---------------------");

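    // Fit a 3-topic LDA model to the count matrix.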
    let mut lda = LatentDirichletAllocation::with_ntopics(3);
    lda.fit(&doc_term_matrix)?;

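    // Invert the vocabulary so get_topics can report words instead of ids
    // (here, the top 5 words per topic).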
    let id_to_word: HashMap<usize, String> = vocabulary
        .iter()
        .map(|(word, &id)| (id, word.clone()))
        .collect();

    let topics = lda.get_topics(5, &id_to_word)?;

    println!("Discovered topics:");
    for (i, topic) in topics.iter().enumerate() {
        println!("\nTopic {}:", i + 1);
        for (word, prob) in &topic.top_words {
            println!(" {word} ({prob:.4})");
        }
    }

    println!("\n2. Topic Coherence Metrics");
    println!("--------------------------");

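    // Window size for the sliding-window metrics (C_v and UCI). UMass, as
    // commonly defined, uses whole-document co-occurrence instead.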
    let coherence_calc = TopicCoherence::new().with_window_size(5);

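    // C_v: in the usual formulation, NPMI statistics over a sliding window are
    // combined via cosine similarity of per-word context vectors. Higher is
    // better, with scores roughly in [0, 1].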
    let cv_coherence = coherence_calc.cv_coherence(&topics, &tokenized_docs)?;
    println!("C_v coherence: {cv_coherence:.4}");

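    // UMass scores each ranked word pair as log((D(w_i, w_j) + 1) / D(w_j))
    // from document co-occurrence counts. Values are typically negative;
    // closer to zero indicates more coherent topics.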
    let umass_coherence = coherence_calc.umass_coherence(&topics, &tokenized_docs)?;
    println!("UMass coherence: {umass_coherence:.4}");

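    // UCI averages pointwise mutual information over top-word pairs using the
    // sliding window; higher is better.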
    let uci_coherence = coherence_calc.uci_coherence(&topics, &tokenized_docs)?;
    println!("UCI coherence: {uci_coherence:.4}");

    println!("\n3. Topic Diversity");
    println!("------------------");

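    // Diversity is commonly the fraction of unique words across all topics'
    // top-word lists: 1.0 means no overlap between topics.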
    let diversity = TopicDiversity::calculate(&topics);
    println!("Topic diversity: {diversity:.4}");

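    // Jaccard distance between two topics' top-word sets: 1 - |A ∩ B| / |A ∪ B|.
    // The diagonal is 0 (a topic matches itself); disjoint topics score 1.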
    let distances = TopicDiversity::pairwise_distances(&topics);
    println!("\nPairwise Jaccard distances between topics:");
    for i in 0..distances.nrows() {
        for j in 0..distances.ncols() {
            print!("{:.3} ", distances[[i, j]]);
        }
        println!();
    }

    println!("\n4. Optimal Topic Number Analysis");
    println!("--------------------------------");

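    // Re-fit the model for several topic counts and score each candidate on
    // both coherence and diversity.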
    let topic_counts = vec![2, 3, 4, 5];
    let mut results = Vec::new();

    for n_topics in topic_counts {
        let mut lda = LatentDirichletAllocation::with_ntopics(n_topics);
        lda.fit(&doc_term_matrix)?;

        let topics = lda.get_topics(5, &id_to_word)?;
        let coherence = coherence_calc.cv_coherence(&topics, &tokenized_docs)?;
        let diversity = TopicDiversity::calculate(&topics);

        results.push((n_topics, coherence, diversity));
        println!("{n_topics} topics: coherence={coherence:.4}, diversity={diversity:.4}");
    }

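    // Rank candidates by a combined score, coherence + 0.5 * diversity. The
    // 0.5 weighting is an ad hoc choice for this demo, not a standard criterion.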
    let optimal = results
        .iter()
        .max_by(|a, b| {
            let score_a = a.1 + 0.5 * a.2;
            let score_b = b.1 + 0.5 * b.2;
            score_a.partial_cmp(&score_b).expect("scores should be comparable")
        })
        .expect("results should not be empty");

    println!(
        "\nOptimal number of topics: {} (coherence={:.4}, diversity={:.4})",
        optimal.0, optimal.1, optimal.2
    );

    println!("\n5. Manual Topic Evaluation");
    println!("--------------------------");

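    // Hand-built topics (e.g. from another model or a domain expert) can be
    // scored with the same metrics.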
    let manual_topics = vec![
        LdaTopic {
            id: 0,
            top_words: vec![
                ("learning".to_string(), 0.15),
                ("machine".to_string(), 0.12),
                ("algorithm".to_string(), 0.10),
                ("data".to_string(), 0.08),
                ("model".to_string(), 0.07),
            ],
            coherence: None,
        },
        LdaTopic {
            id: 1,
            top_words: vec![
                ("network".to_string(), 0.14),
                ("neural".to_string(), 0.13),
                ("deep".to_string(), 0.11),
                ("layer".to_string(), 0.09),
                ("process".to_string(), 0.08),
            ],
            coherence: None,
        },
    ];

    let manual_coherence = coherence_calc.cv_coherence(&manual_topics, &tokenized_docs)?;
    let manual_diversity = TopicDiversity::calculate(&manual_topics);

    println!("Manual topics coherence: {manual_coherence:.4}");
    println!("Manual topics diversity: {manual_diversity:.4}");

    Ok(())
}
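
// For intuition, a minimal sketch of UMass coherence computed directly from the
// tokenized corpus. It assumes the common definition
//     C_UMass = mean over ranked pairs of log((D(w_i, w_j) + 1) / D(w_j)),
// where D(.) counts documents containing the given word(s); the library's own
// implementation may differ in smoothing, ordering, or normalization. The
// function name and signature are illustrative, not part of scirs2_text.
// Example call: umass_sketch(&["learning", "machine", "data"], &tokenized_docs).
#[allow(dead_code)]
fn umass_sketch(topic_words: &[&str], docs: &[Vec<String>]) -> f64 {
    // D(w): number of documents containing the word at least once.
    let doc_freq = |w: &str| docs.iter().filter(|d| d.iter().any(|t| t == w)).count() as f64;
    // D(a, b): number of documents containing both words.
    let co_freq = |a: &str, b: &str| {
        docs.iter()
            .filter(|d| d.iter().any(|t| t == a) && d.iter().any(|t| t == b))
            .count() as f64
    };

    let mut sum = 0.0;
    let mut pairs = 0.0;
    for i in 1..topic_words.len() {
        for j in 0..i {
            let d_j = doc_freq(topic_words[j]);
            if d_j > 0.0 {
                sum += ((co_freq(topic_words[i], topic_words[j]) + 1.0) / d_j).ln();
                pairs += 1.0;
            }
        }
    }
    if pairs > 0.0 {
        sum / pairs
    } else {
        0.0
    }
}

// Likewise, topic diversity in its common form: the fraction of unique words
// across all topics' top-word lists (1.0 means the topics share no words).
#[allow(dead_code)]
fn diversity_sketch(topics: &[Vec<String>]) -> f64 {
    let unique: std::collections::HashSet<&String> = topics.iter().flatten().collect();
    let total: usize = topics.iter().map(|t| t.len()).sum();
    if total == 0 {
        0.0
    } else {
        unique.len() as f64 / total as f64
    }
}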