scirs2_text/
summarize_advanced.rs1use std::collections::HashMap;
8
9use crate::error::{Result, TextError};
10
/// Pairwise sentence similarity measure used to weight the edges of the
/// sentence graph built by the TextRank summarizer.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SentenceSimilarity {
    /// Cosine similarity between TF-IDF weighted term vectors.
    CosineTFIDF,
    /// BM25 score of one sentence (as the query) against the other.
    BM25,
    /// Jaccard overlap between the two sentences' sets of words.
    Jaccard,
}
25
/// Splits `text` into trimmed sentences.
///
/// A sentence ends at `.`, `?` or `!` when that character is the last one in
/// the text or is immediately followed by whitespace, so "3.14" and the first
/// two dots of "..." do not end a sentence. Any trailing text without a
/// terminator becomes the final sentence. Pieces that are empty after
/// trimming are dropped.
fn split_sentences(text: &str) -> Vec<String> {
    let mut out: Vec<String> = Vec::new();
    let mut buf = String::new();
    let mut stream = text.chars().peekable();

    while let Some(ch) = stream.next() {
        buf.push(ch);

        let is_terminator = matches!(ch, '.' | '?' | '!');
        // End of input counts as a boundary, just like following whitespace.
        let at_boundary = stream.peek().map_or(true, |next| next.is_whitespace());

        if is_terminator && at_boundary {
            let sentence = buf.trim();
            if !sentence.is_empty() {
                out.push(sentence.to_string());
            }
            buf.clear();
        }
    }

    // Whatever is left had no terminating punctuation.
    let tail = buf.trim();
    if !tail.is_empty() {
        out.push(tail.to_string());
    }
    out
}
58
/// Breaks `sentence` into lowercase words.
///
/// Every non-alphabetic character (digits, punctuation, whitespace) acts as a
/// separator; empty pieces between consecutive separators are discarded.
fn tokenize_words(sentence: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in sentence.split(|ch: char| !ch.is_alphabetic()) {
        if piece.is_empty() {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
67
68fn jaccard(a: &str, b: &str) -> f64 {
74 let ta: std::collections::HashSet<String> = tokenize_words(a).into_iter().collect();
75 let tb: std::collections::HashSet<String> = tokenize_words(b).into_iter().collect();
76 let inter = ta.intersection(&tb).count();
77 let union = ta.union(&tb).count();
78 if union == 0 {
79 0.0
80 } else {
81 inter as f64 / union as f64
82 }
83}
84
85fn tf_map(sentence: &str) -> HashMap<String, f64> {
87 let words = tokenize_words(sentence);
88 let n = words.len() as f64;
89 if n == 0.0 {
90 return HashMap::new();
91 }
92 let mut counts: HashMap<String, usize> = HashMap::new();
93 for w in words {
94 *counts.entry(w).or_insert(0) += 1;
95 }
96 counts.into_iter().map(|(k, c)| (k, c as f64 / n)).collect()
97}
98
99fn cosine_tfidf(a: &str, b: &str, idf: &HashMap<String, f64>) -> f64 {
101 let ta = tf_map(a);
102 let tb = tf_map(b);
103 let dot: f64 = ta
104 .iter()
105 .filter_map(|(w, &tfa)| {
106 tb.get(w).map(|&tfb| {
107 let idf_w = idf.get(w).copied().unwrap_or(1.0);
108 tfa * idf_w * tfb * idf_w
109 })
110 })
111 .sum();
112 let norm_a: f64 = ta
113 .values()
114 .map(|&v| {
115 let idf_w = idf.get(&String::new()).copied().unwrap_or(1.0);
116 (v * idf_w).powi(2)
117 })
118 .sum::<f64>()
119 .sqrt();
120 let norm_b: f64 = tb
121 .values()
122 .map(|&v| {
123 let idf_w = idf.get(&String::new()).copied().unwrap_or(1.0);
124 (v * idf_w).powi(2)
125 })
126 .sum::<f64>()
127 .sqrt();
128 if norm_a == 0.0 || norm_b == 0.0 {
129 return 0.0;
130 }
131 dot / (norm_a * norm_b)
132}
133
134fn bm25_similarity(
136 query: &str,
137 doc: &str,
138 avgdl: f64,
139 idf: &HashMap<String, f64>,
140 k1: f64,
141 b: f64,
142) -> f64 {
143 let query_words = tokenize_words(query);
144 let doc_words = tokenize_words(doc);
145 let dl = doc_words.len() as f64;
146 let mut freq_map: HashMap<&str, usize> = HashMap::new();
147 for w in &doc_words {
148 *freq_map.entry(w.as_str()).or_insert(0) += 1;
149 }
150 query_words
151 .iter()
152 .map(|w| {
153 let idf_w = idf.get(w).copied().unwrap_or(1.0);
154 let f = *freq_map.get(w.as_str()).unwrap_or(&0) as f64;
155 idf_w * (f * (k1 + 1.0)) / (f + k1 * (1.0 - b + b * dl / avgdl))
156 })
157 .sum()
158}
159
160fn build_idf(sentences: &[String]) -> HashMap<String, f64> {
162 let n = sentences.len() as f64;
163 let mut df: HashMap<String, usize> = HashMap::new();
164 for sent in sentences {
165 let words: std::collections::HashSet<String> = tokenize_words(sent).into_iter().collect();
166 for w in words {
167 *df.entry(w).or_insert(0) += 1;
168 }
169 }
170 df.into_iter()
171 .map(|(w, c)| (w, ((n + 1.0) / (c as f64 + 1.0)).ln() + 1.0))
172 .collect()
173}
174
175#[derive(Debug, Clone)]
184pub struct TextRankSummarizer {
185 pub damping: f64,
187 pub n_iterations: usize,
189 pub similarity: SentenceSimilarity,
191}
192
193impl Default for TextRankSummarizer {
194 fn default() -> Self {
195 TextRankSummarizer {
196 damping: 0.85,
197 n_iterations: 50,
198 similarity: SentenceSimilarity::CosineTFIDF,
199 }
200 }
201}
202
203impl TextRankSummarizer {
204 pub fn new(damping: f64, n_iterations: usize, similarity: SentenceSimilarity) -> Self {
206 TextRankSummarizer {
207 damping,
208 n_iterations,
209 similarity,
210 }
211 }
212
213 pub fn summarize(&self, text: &str, n_sentences: usize) -> Result<String> {
217 let sentences = split_sentences(text);
218 if sentences.is_empty() {
219 return Ok(String::new());
220 }
221 let k = n_sentences.min(sentences.len());
222 if k == sentences.len() {
223 return Ok(text.to_string());
224 }
225
226 let idf = build_idf(&sentences);
227 let n = sentences.len();
228 let avgdl = sentences
229 .iter()
230 .map(|s| tokenize_words(s).len())
231 .sum::<usize>() as f64
232 / n as f64;
233
234 let mut adj = vec![vec![0.0f64; n]; n];
236 for i in 0..n {
237 for j in 0..n {
238 if i == j {
239 continue;
240 }
241 adj[i][j] = match self.similarity {
242 SentenceSimilarity::Jaccard => jaccard(&sentences[i], &sentences[j]),
243 SentenceSimilarity::CosineTFIDF => {
244 cosine_tfidf(&sentences[i], &sentences[j], &idf)
245 }
246 SentenceSimilarity::BM25 => {
247 bm25_similarity(&sentences[i], &sentences[j], avgdl, &idf, 1.5, 0.75)
248 }
249 };
250 }
251 }
252
253 for row in adj.iter_mut() {
255 let total: f64 = row.iter().sum();
256 if total > 0.0 {
257 for v in row.iter_mut() {
258 *v /= total;
259 }
260 }
261 }
262
263 let mut scores = vec![1.0 / n as f64; n];
265 for _ in 0..self.n_iterations {
266 let mut new_scores = vec![0.0f64; n];
267 for j in 0..n {
268 for i in 0..n {
269 new_scores[j] += adj[i][j] * scores[i];
270 }
271 new_scores[j] = (1.0 - self.damping) / n as f64 + self.damping * new_scores[j];
272 }
273 scores = new_scores;
274 }
275
276 let mut ranked: Vec<(usize, f64)> = scores.iter().cloned().enumerate().collect();
278 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
279 let mut top_indices: Vec<usize> = ranked.iter().take(k).map(|&(i, _)| i).collect();
280 top_indices.sort();
282
283 let summary = top_indices
284 .iter()
285 .map(|&i| sentences[i].as_str())
286 .collect::<Vec<_>>()
287 .join(" ");
288 Ok(summary)
289 }
290}
291
/// Namespace for simple extractive summarization baselines; it carries no
/// state and its operations are associated functions.
#[derive(Debug, Clone, Copy, Default)]
pub struct ExtractiveSummarizer;
298
299impl ExtractiveSummarizer {
300 pub fn lead_k(text: &str, k: usize) -> Result<String> {
302 if k == 0 {
303 return Err(TextError::InvalidInput("k must be at least 1".to_string()));
304 }
305 let sentences = split_sentences(text);
306 let selected: Vec<&str> = sentences.iter().take(k).map(String::as_str).collect();
307 Ok(selected.join(" "))
308 }
309
310 pub fn frequency_based(text: &str, k: usize) -> Result<String> {
315 if k == 0 {
316 return Err(TextError::InvalidInput("k must be at least 1".to_string()));
317 }
318 let sentences = split_sentences(text);
319 if sentences.is_empty() {
320 return Ok(String::new());
321 }
322
323 let mut freq: HashMap<String, usize> = HashMap::new();
325 for sent in &sentences {
326 for w in tokenize_words(sent) {
327 *freq.entry(w).or_insert(0) += 1;
328 }
329 }
330 let max_freq = *freq.values().max().unwrap_or(&1) as f64;
332 let norm_freq: HashMap<String, f64> = freq
333 .into_iter()
334 .map(|(k, v)| (k, v as f64 / max_freq))
335 .collect();
336
337 let mut scored: Vec<(usize, f64)> = sentences
339 .iter()
340 .enumerate()
341 .map(|(i, sent)| {
342 let words = tokenize_words(sent);
343 let score: f64 = words
344 .iter()
345 .map(|w| norm_freq.get(w).copied().unwrap_or(0.0))
346 .sum();
347 (i, score)
348 })
349 .collect();
350
351 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
353 let mut top_indices: Vec<usize> = scored.iter().take(k).map(|&(i, _)| i).collect();
354 top_indices.sort();
355
356 let result = top_indices
357 .iter()
358 .map(|&i| sentences[i].as_str())
359 .collect::<Vec<_>>()
360 .join(" ");
361 Ok(result)
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-sentence fixture shared by the summarizer tests.
    const TEXT: &str = "The quick brown fox jumps over the lazy dog. \
        A fox is a cunning animal. \
        Dogs are loyal companions. \
        Foxes live in dens and are mostly nocturnal. \
        The dog slept all afternoon.";

    #[test]
    fn test_split_sentences() {
        let pieces = split_sentences(TEXT);
        assert_eq!(pieces.len(), 5);
    }

    #[test]
    fn test_textrank_summarize_count() {
        let ranker = TextRankSummarizer::default();
        let out = ranker.summarize(TEXT, 2).expect("summarize failed");
        // Each fixture sentence ends in exactly one period, so periods count sentences.
        let count = out.matches('.').count();
        assert!(count <= 2, "too many sentences: {}", count);
    }

    #[test]
    fn test_textrank_empty_text() {
        let ranker = TextRankSummarizer::default();
        let out = ranker.summarize("", 3).expect("summarize empty");
        assert!(out.is_empty());
    }

    #[test]
    fn test_textrank_more_than_available() {
        let ranker = TextRankSummarizer::default();
        // Asking for more sentences than exist must still produce output.
        let out = ranker.summarize(TEXT, 100).expect("summarize");
        assert!(!out.is_empty());
    }

    #[test]
    fn test_textrank_bm25() {
        let ranker = TextRankSummarizer::new(0.85, 20, SentenceSimilarity::BM25);
        let out = ranker.summarize(TEXT, 2).expect("summarize bm25");
        assert!(!out.is_empty());
    }

    #[test]
    fn test_textrank_jaccard() {
        let ranker = TextRankSummarizer::new(0.85, 20, SentenceSimilarity::Jaccard);
        let out = ranker.summarize(TEXT, 2).expect("summarize jaccard");
        assert!(!out.is_empty());
    }

    #[test]
    fn test_lead_k() {
        let out = ExtractiveSummarizer::lead_k(TEXT, 2).expect("lead_k");
        assert!(out.starts_with("The quick"));
    }

    #[test]
    fn test_lead_k_zero_error() {
        let outcome = ExtractiveSummarizer::lead_k(TEXT, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_frequency_based() {
        let out = ExtractiveSummarizer::frequency_based(TEXT, 2).expect("freq_based");
        assert!(!out.is_empty());
    }

    #[test]
    fn test_jaccard_similarity() {
        // Identical sentences overlap completely.
        let left = "the cat sat on the mat";
        let right = "the cat sat on the mat";
        let score = jaccard(left, right);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_jaccard_no_overlap() {
        let left = "hello world";
        let right = "foo bar baz";
        let score = jaccard(left, right);
        assert!(score < 0.01);
    }

    #[test]
    fn test_build_idf() {
        let corpus = vec!["the cat sat".to_string(), "the dog ran".to_string()];
        let table = build_idf(&corpus);
        // "the" occurs in both sentences, "cat" in one, so "cat" must not weigh less.
        let common = table.get("the").copied().unwrap_or(0.0);
        let rare = table.get("cat").copied().unwrap_or(0.0);
        assert!(rare >= common);
    }
}